agent.rs

   1use crate::{
   2    AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
   3    DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
   4    MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
   5    ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
   6};
   7use acp_thread::AgentModelSelector;
   8use agent_client_protocol as acp;
   9use agent_settings::AgentSettings;
  10use anyhow::{Context as _, Result, anyhow};
  11use collections::{HashSet, IndexMap};
  12use fs::Fs;
  13use futures::channel::mpsc;
  14use futures::{StreamExt, future};
  15use gpui::{
  16    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  17};
  18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  19use project::{Project, ProjectItem, ProjectPath, Worktree};
  20use prompt_store::{
  21    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
  22};
  23use settings::update_settings_file;
  24use std::any::Any;
  25use std::collections::HashMap;
  26use std::path::Path;
  27use std::rc::Rc;
  28use std::sync::Arc;
  29use util::ResultExt;
  30
  31const RULES_FILE_NAMES: [&'static str; 9] = [
  32    ".rules",
  33    ".cursorrules",
  34    ".windsurfrules",
  35    ".clinerules",
  36    ".github/copilot-instructions.md",
  37    "CLAUDE.md",
  38    "AGENT.md",
  39    "AGENTS.md",
  40    "GEMINI.md",
  41];
  42
  43pub struct RulesLoadingError {
  44    pub message: SharedString,
  45}
  46
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    ///
    /// Held weakly; `NativeAgentConnection::run_turn` fails a turn when this
    /// can no longer be upgraded.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Observes the release of the `AcpThread`; its callback removes this
    /// session from `NativeAgent::sessions`.
    _subscription: Subscription,
}
  55
/// Cache of the language models offered by authenticated providers, both as
/// an id-keyed lookup table and as a pre-grouped list for display.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()`; signalled each time the list is rebuilt.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
}
  64
  65impl LanguageModels {
  66    fn new(cx: &App) -> Self {
  67        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  68        let mut this = Self {
  69            models: HashMap::default(),
  70            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  71            refresh_models_rx,
  72            refresh_models_tx,
  73        };
  74        this.refresh_list(cx);
  75        this
  76    }
  77
  78    fn refresh_list(&mut self, cx: &App) {
  79        let providers = LanguageModelRegistry::global(cx)
  80            .read(cx)
  81            .providers()
  82            .into_iter()
  83            .filter(|provider| provider.is_authenticated(cx))
  84            .collect::<Vec<_>>();
  85
  86        let mut language_model_list = IndexMap::default();
  87        let mut recommended_models = HashSet::default();
  88
  89        let mut recommended = Vec::new();
  90        for provider in &providers {
  91            for model in provider.recommended_models(cx) {
  92                recommended_models.insert(model.id());
  93                recommended.push(Self::map_language_model_to_info(&model, provider));
  94            }
  95        }
  96        if !recommended.is_empty() {
  97            language_model_list.insert(
  98                acp_thread::AgentModelGroupName("Recommended".into()),
  99                recommended,
 100            );
 101        }
 102
 103        let mut models = HashMap::default();
 104        for provider in providers {
 105            let mut provider_models = Vec::new();
 106            for model in provider.provided_models(cx) {
 107                let model_info = Self::map_language_model_to_info(&model, &provider);
 108                let model_id = model_info.id.clone();
 109                if !recommended_models.contains(&model.id()) {
 110                    provider_models.push(model_info);
 111                }
 112                models.insert(model_id, model);
 113            }
 114            if !provider_models.is_empty() {
 115                language_model_list.insert(
 116                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 117                    provider_models,
 118                );
 119            }
 120        }
 121
 122        self.models = models;
 123        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 124        self.refresh_models_tx.send(()).ok();
 125    }
 126
 127    fn watch(&self) -> watch::Receiver<()> {
 128        self.refresh_models_rx.clone()
 129    }
 130
 131    pub fn model_from_id(
 132        &self,
 133        model_id: &acp_thread::AgentModelId,
 134    ) -> Option<Arc<dyn LanguageModel>> {
 135        self.models.get(model_id).cloned()
 136    }
 137
 138    fn map_language_model_to_info(
 139        model: &Arc<dyn LanguageModel>,
 140        provider: &Arc<dyn LanguageModelProvider>,
 141    ) -> acp_thread::AgentModelInfo {
 142        acp_thread::AgentModelInfo {
 143            id: Self::model_id(model),
 144            name: model.name().0,
 145            icon: Some(provider.icon()),
 146        }
 147    }
 148
 149    fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
 150        acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
 151    }
 152}
 153
/// The in-process agent: owns every session plus the state shared by their
/// threads (project context, context servers, templates, and model cache).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled to ask `_maintain_project_context` to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `NativeAgent::maintain_project_context`.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry handed to every new `Thread`.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Source of the default user rules, if available.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem used when persisting the selected model to settings.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, language-model-registry and prompt-store events.
    _subscriptions: Vec<Subscription>,
}
 171
 172impl NativeAgent {
 173    pub async fn new(
 174        project: Entity<Project>,
 175        templates: Arc<Templates>,
 176        prompt_store: Option<Entity<PromptStore>>,
 177        fs: Arc<dyn Fs>,
 178        cx: &mut AsyncApp,
 179    ) -> Result<Entity<NativeAgent>> {
 180        log::info!("Creating new NativeAgent");
 181
 182        let project_context = cx
 183            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 184            .await;
 185
 186        cx.new(|cx| {
 187            let mut subscriptions = vec![
 188                cx.subscribe(&project, Self::handle_project_event),
 189                cx.subscribe(
 190                    &LanguageModelRegistry::global(cx),
 191                    Self::handle_models_updated_event,
 192                ),
 193            ];
 194            if let Some(prompt_store) = prompt_store.as_ref() {
 195                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 196            }
 197
 198            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 199                watch::channel(());
 200            Self {
 201                sessions: HashMap::new(),
 202                project_context: cx.new(|_| project_context),
 203                project_context_needs_refresh: project_context_needs_refresh_tx,
 204                _maintain_project_context: cx.spawn(async move |this, cx| {
 205                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 206                }),
 207                context_server_registry: cx.new(|cx| {
 208                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 209                }),
 210                templates,
 211                models: LanguageModels::new(cx),
 212                project,
 213                prompt_store,
 214                fs,
 215                _subscriptions: subscriptions,
 216            }
 217        })
 218    }
 219
 220    pub fn models(&self) -> &LanguageModels {
 221        &self.models
 222    }
 223
 224    async fn maintain_project_context(
 225        this: WeakEntity<Self>,
 226        mut needs_refresh: watch::Receiver<()>,
 227        cx: &mut AsyncApp,
 228    ) -> Result<()> {
 229        while needs_refresh.changed().await.is_ok() {
 230            let project_context = this
 231                .update(cx, |this, cx| {
 232                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 233                })?
 234                .await;
 235            this.update(cx, |this, cx| {
 236                this.project_context = cx.new(|_| project_context);
 237            })?;
 238        }
 239
 240        Ok(())
 241    }
 242
 243    fn build_project_context(
 244        project: &Entity<Project>,
 245        prompt_store: Option<&Entity<PromptStore>>,
 246        cx: &mut App,
 247    ) -> Task<ProjectContext> {
 248        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 249        let worktree_tasks = worktrees
 250            .into_iter()
 251            .map(|worktree| {
 252                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 253            })
 254            .collect::<Vec<_>>();
 255        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 256            prompt_store.read_with(cx, |prompt_store, cx| {
 257                let prompts = prompt_store.default_prompt_metadata();
 258                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 259                    let contents = prompt_store.load(prompt_metadata.id, cx);
 260                    async move { (contents.await, prompt_metadata) }
 261                });
 262                cx.background_spawn(future::join_all(load_tasks))
 263            })
 264        } else {
 265            Task::ready(vec![])
 266        };
 267
 268        cx.spawn(async move |_cx| {
 269            let (worktrees, default_user_rules) =
 270                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 271
 272            let worktrees = worktrees
 273                .into_iter()
 274                .map(|(worktree, _rules_error)| {
 275                    // TODO: show error message
 276                    // if let Some(rules_error) = rules_error {
 277                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 278                    // }
 279                    worktree
 280                })
 281                .collect::<Vec<_>>();
 282
 283            let default_user_rules = default_user_rules
 284                .into_iter()
 285                .flat_map(|(contents, prompt_metadata)| match contents {
 286                    Ok(contents) => Some(UserRulesContext {
 287                        uuid: match prompt_metadata.id {
 288                            PromptId::User { uuid } => uuid,
 289                            PromptId::EditWorkflow => return None,
 290                        },
 291                        title: prompt_metadata.title.map(|title| title.to_string()),
 292                        contents,
 293                    }),
 294                    Err(_err) => {
 295                        // TODO: show error message
 296                        // this.update(cx, |_, cx| {
 297                        //     cx.emit(RulesLoadingError {
 298                        //         message: format!("{err:?}").into(),
 299                        //     });
 300                        // })
 301                        // .ok();
 302                        None
 303                    }
 304                })
 305                .collect::<Vec<_>>();
 306
 307            ProjectContext::new(worktrees, default_user_rules)
 308        })
 309    }
 310
 311    fn load_worktree_info_for_system_prompt(
 312        worktree: Entity<Worktree>,
 313        project: Entity<Project>,
 314        cx: &mut App,
 315    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 316        let tree = worktree.read(cx);
 317        let root_name = tree.root_name().into();
 318        let abs_path = tree.abs_path();
 319
 320        let mut context = WorktreeContext {
 321            root_name,
 322            abs_path,
 323            rules_file: None,
 324        };
 325
 326        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 327        let Some(rules_task) = rules_task else {
 328            return Task::ready((context, None));
 329        };
 330
 331        cx.spawn(async move |_| {
 332            let (rules_file, rules_file_error) = match rules_task.await {
 333                Ok(rules_file) => (Some(rules_file), None),
 334                Err(err) => (
 335                    None,
 336                    Some(RulesLoadingError {
 337                        message: format!("{err}").into(),
 338                    }),
 339                ),
 340            };
 341            context.rules_file = rules_file;
 342            (context, rules_file_error)
 343        })
 344    }
 345
 346    fn load_worktree_rules_file(
 347        worktree: Entity<Worktree>,
 348        project: Entity<Project>,
 349        cx: &mut App,
 350    ) -> Option<Task<Result<RulesFileContext>>> {
 351        let worktree = worktree.read(cx);
 352        let worktree_id = worktree.id();
 353        let selected_rules_file = RULES_FILE_NAMES
 354            .into_iter()
 355            .filter_map(|name| {
 356                worktree
 357                    .entry_for_path(name)
 358                    .filter(|entry| entry.is_file())
 359                    .map(|entry| entry.path.clone())
 360            })
 361            .next();
 362
 363        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 364        // supported. This doesn't seem to occur often in GitHub repositories.
 365        selected_rules_file.map(|path_in_worktree| {
 366            let project_path = ProjectPath {
 367                worktree_id,
 368                path: path_in_worktree.clone(),
 369            };
 370            let buffer_task =
 371                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 372            let rope_task = cx.spawn(async move |cx| {
 373                buffer_task.await?.read_with(cx, |buffer, cx| {
 374                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 375                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 376                })?
 377            });
 378            // Build a string from the rope on a background thread.
 379            cx.background_spawn(async move {
 380                let (project_entry_id, rope) = rope_task.await?;
 381                anyhow::Ok(RulesFileContext {
 382                    path_in_worktree,
 383                    text: rope.to_string().trim().to_string(),
 384                    project_entry_id: project_entry_id.to_usize(),
 385                })
 386            })
 387        })
 388    }
 389
 390    fn handle_project_event(
 391        &mut self,
 392        _project: Entity<Project>,
 393        event: &project::Event,
 394        _cx: &mut Context<Self>,
 395    ) {
 396        match event {
 397            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 398                self.project_context_needs_refresh.send(()).ok();
 399            }
 400            project::Event::WorktreeUpdatedEntries(_, items) => {
 401                if items.iter().any(|(path, _, _)| {
 402                    RULES_FILE_NAMES
 403                        .iter()
 404                        .any(|name| path.as_ref() == Path::new(name))
 405                }) {
 406                    self.project_context_needs_refresh.send(()).ok();
 407                }
 408            }
 409            _ => {}
 410        }
 411    }
 412
 413    fn handle_prompts_updated_event(
 414        &mut self,
 415        _prompt_store: Entity<PromptStore>,
 416        _event: &prompt_store::PromptsUpdatedEvent,
 417        _cx: &mut Context<Self>,
 418    ) {
 419        self.project_context_needs_refresh.send(()).ok();
 420    }
 421
 422    fn handle_models_updated_event(
 423        &mut self,
 424        _registry: Entity<LanguageModelRegistry>,
 425        _event: &language_model::Event,
 426        cx: &mut Context<Self>,
 427    ) {
 428        self.models.refresh_list(cx);
 429
 430        let default_model = LanguageModelRegistry::read_global(cx)
 431            .default_model()
 432            .map(|m| m.model.clone());
 433
 434        for session in self.sessions.values_mut() {
 435            session.thread.update(cx, |thread, cx| {
 436                if thread.model().is_none()
 437                    && let Some(model) = default_model.clone()
 438                {
 439                    thread.set_model(model);
 440                    cx.notify();
 441                }
 442            });
 443        }
 444    }
 445}
 446
/// Wrapper struct that implements the AgentConnection trait
///
/// Also implements `AgentModelSelector`. Cloning clones the wrapped
/// `Entity<NativeAgent>` handle, so clones share the same agent entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 450
 451impl NativeAgentConnection {
 452    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 453        self.0
 454            .read(cx)
 455            .sessions
 456            .get(session_id)
 457            .map(|session| session.thread.clone())
 458    }
 459
 460    fn run_turn(
 461        &self,
 462        session_id: acp::SessionId,
 463        cx: &mut App,
 464        f: impl 'static
 465        + FnOnce(
 466            Entity<Thread>,
 467            &mut App,
 468        ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
 469    ) -> Task<Result<acp::PromptResponse>> {
 470        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 471            agent
 472                .sessions
 473                .get_mut(&session_id)
 474                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 475        }) else {
 476            return Task::ready(Err(anyhow!("Session not found")));
 477        };
 478        log::debug!("Found session for: {}", session_id);
 479
 480        let mut response_stream = match f(thread, cx) {
 481            Ok(stream) => stream,
 482            Err(err) => return Task::ready(Err(err)),
 483        };
 484        cx.spawn(async move |cx| {
 485            // Handle response stream and forward to session.acp_thread
 486            while let Some(result) = response_stream.next().await {
 487                match result {
 488                    Ok(event) => {
 489                        log::trace!("Received completion event: {:?}", event);
 490
 491                        match event {
 492                            AgentResponseEvent::Text(text) => {
 493                                acp_thread.update(cx, |thread, cx| {
 494                                    thread.push_assistant_content_block(
 495                                        acp::ContentBlock::Text(acp::TextContent {
 496                                            text,
 497                                            annotations: None,
 498                                        }),
 499                                        false,
 500                                        cx,
 501                                    )
 502                                })?;
 503                            }
 504                            AgentResponseEvent::Thinking(text) => {
 505                                acp_thread.update(cx, |thread, cx| {
 506                                    thread.push_assistant_content_block(
 507                                        acp::ContentBlock::Text(acp::TextContent {
 508                                            text,
 509                                            annotations: None,
 510                                        }),
 511                                        true,
 512                                        cx,
 513                                    )
 514                                })?;
 515                            }
 516                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
 517                                tool_call,
 518                                options,
 519                                response,
 520                            }) => {
 521                                let recv = acp_thread.update(cx, |thread, cx| {
 522                                    thread.request_tool_call_authorization(tool_call, options, cx)
 523                                })?;
 524                                cx.background_spawn(async move {
 525                                    if let Some(recv) = recv.log_err()
 526                                        && let Some(option) = recv
 527                                            .await
 528                                            .context("authorization sender was dropped")
 529                                            .log_err()
 530                                    {
 531                                        response
 532                                            .send(option)
 533                                            .map(|_| anyhow!("authorization receiver was dropped"))
 534                                            .log_err();
 535                                    }
 536                                })
 537                                .detach();
 538                            }
 539                            AgentResponseEvent::ToolCall(tool_call) => {
 540                                acp_thread.update(cx, |thread, cx| {
 541                                    thread.upsert_tool_call(tool_call, cx)
 542                                })??;
 543                            }
 544                            AgentResponseEvent::ToolCallUpdate(update) => {
 545                                acp_thread.update(cx, |thread, cx| {
 546                                    thread.update_tool_call(update, cx)
 547                                })??;
 548                            }
 549                            AgentResponseEvent::Stop(stop_reason) => {
 550                                log::debug!("Assistant message complete: {:?}", stop_reason);
 551                                return Ok(acp::PromptResponse { stop_reason });
 552                            }
 553                        }
 554                    }
 555                    Err(e) => {
 556                        log::error!("Error in model response stream: {:?}", e);
 557                        return Err(e);
 558                    }
 559                }
 560            }
 561
 562            log::info!("Response stream completed");
 563            anyhow::Ok(acp::PromptResponse {
 564                stop_reason: acp::StopReason::EndTurn,
 565            })
 566        })
 567    }
 568}
 569
 570impl AgentModelSelector for NativeAgentConnection {
 571    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 572        log::debug!("NativeAgentConnection::list_models called");
 573        let list = self.0.read(cx).models.model_list.clone();
 574        Task::ready(if list.is_empty() {
 575            Err(anyhow::anyhow!("No models available"))
 576        } else {
 577            Ok(list)
 578        })
 579    }
 580
 581    fn select_model(
 582        &self,
 583        session_id: acp::SessionId,
 584        model_id: acp_thread::AgentModelId,
 585        cx: &mut App,
 586    ) -> Task<Result<()>> {
 587        log::info!("Setting model for session {}: {}", session_id, model_id);
 588        let Some(thread) = self
 589            .0
 590            .read(cx)
 591            .sessions
 592            .get(&session_id)
 593            .map(|session| session.thread.clone())
 594        else {
 595            return Task::ready(Err(anyhow!("Session not found")));
 596        };
 597
 598        let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
 599            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 600        };
 601
 602        thread.update(cx, |thread, _cx| {
 603            thread.set_model(model.clone());
 604        });
 605
 606        update_settings_file::<AgentSettings>(
 607            self.0.read(cx).fs.clone(),
 608            cx,
 609            move |settings, _cx| {
 610                settings.set_model(model);
 611            },
 612        );
 613
 614        Task::ready(Ok(()))
 615    }
 616
 617    fn selected_model(
 618        &self,
 619        session_id: &acp::SessionId,
 620        cx: &mut App,
 621    ) -> Task<Result<acp_thread::AgentModelInfo>> {
 622        let session_id = session_id.clone();
 623
 624        let Some(thread) = self
 625            .0
 626            .read(cx)
 627            .sessions
 628            .get(&session_id)
 629            .map(|session| session.thread.clone())
 630        else {
 631            return Task::ready(Err(anyhow!("Session not found")));
 632        };
 633        let Some(model) = thread.read(cx).model() else {
 634            return Task::ready(Err(anyhow!("Model not found")));
 635        };
 636        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 637        else {
 638            return Task::ready(Err(anyhow!("Provider not found")));
 639        };
 640        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 641            model, &provider,
 642        )))
 643    }
 644
 645    fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
 646        self.0.read(cx).models.watch()
 647    }
 648}
 649
 650impl acp_thread::AgentConnection for NativeAgentConnection {
 651    fn new_thread(
 652        self: Rc<Self>,
 653        project: Entity<Project>,
 654        cwd: &Path,
 655        cx: &mut App,
 656    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 657        let agent = self.0.clone();
 658        log::info!("Creating new thread for project at: {:?}", cwd);
 659
 660        cx.spawn(async move |cx| {
 661            log::debug!("Starting thread creation in async context");
 662
 663            // Generate session ID
 664            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
 665            log::info!("Created session with ID: {}", session_id);
 666
 667            // Create AcpThread
 668            let acp_thread = cx.update(|cx| {
 669                cx.new(|cx| {
 670                    acp_thread::AcpThread::new(
 671                        "agent2",
 672                        self.clone(),
 673                        project.clone(),
 674                        session_id.clone(),
 675                        cx,
 676                    )
 677                })
 678            })?;
 679            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
 680
 681            // Create Thread
 682            let thread = agent.update(
 683                cx,
 684                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 685                    // Fetch default model from registry settings
 686                    let registry = LanguageModelRegistry::read_global(cx);
 687
 688                    // Log available models for debugging
 689                    let available_count = registry.available_models(cx).count();
 690                    log::debug!("Total available models: {}", available_count);
 691
 692                    let default_model = registry.default_model().and_then(|default_model| {
 693                        agent
 694                            .models
 695                            .model_from_id(&LanguageModels::model_id(&default_model.model))
 696                    });
 697
 698                    let thread = cx.new(|cx| {
 699                        let mut thread = Thread::new(
 700                            project.clone(),
 701                            agent.project_context.clone(),
 702                            agent.context_server_registry.clone(),
 703                            action_log.clone(),
 704                            agent.templates.clone(),
 705                            default_model,
 706                            cx,
 707                        );
 708                        thread.add_tool(CopyPathTool::new(project.clone()));
 709                        thread.add_tool(CreateDirectoryTool::new(project.clone()));
 710                        thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
 711                        thread.add_tool(DiagnosticsTool::new(project.clone()));
 712                        thread.add_tool(EditFileTool::new(cx.entity()));
 713                        thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
 714                        thread.add_tool(FindPathTool::new(project.clone()));
 715                        thread.add_tool(GrepTool::new(project.clone()));
 716                        thread.add_tool(ListDirectoryTool::new(project.clone()));
 717                        thread.add_tool(MovePathTool::new(project.clone()));
 718                        thread.add_tool(NowTool);
 719                        thread.add_tool(OpenTool::new(project.clone()));
 720                        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
 721                        thread.add_tool(TerminalTool::new(project.clone(), cx));
 722                        thread.add_tool(ThinkingTool);
 723                        thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
 724                        thread
 725                    });
 726
 727                    Ok(thread)
 728                },
 729            )??;
 730
 731            // Store the session
 732            agent.update(cx, |agent, cx| {
 733                agent.sessions.insert(
 734                    session_id,
 735                    Session {
 736                        thread,
 737                        acp_thread: acp_thread.downgrade(),
 738                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 739                            this.sessions.remove(acp_thread.session_id());
 740                        }),
 741                    },
 742                );
 743            })?;
 744
 745            Ok(acp_thread)
 746        })
 747    }
 748
 749    fn auth_methods(&self) -> &[acp::AuthMethod] {
 750        &[] // No auth for in-process
 751    }
 752
 753    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 754        Task::ready(Ok(()))
 755    }
 756
 757    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 758        Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
 759    }
 760
 761    fn prompt(
 762        &self,
 763        id: Option<acp_thread::UserMessageId>,
 764        params: acp::PromptRequest,
 765        cx: &mut App,
 766    ) -> Task<Result<acp::PromptResponse>> {
 767        let id = id.expect("UserMessageId is required");
 768        let session_id = params.session_id.clone();
 769        log::info!("Received prompt request for session: {}", session_id);
 770        log::debug!("Prompt blocks count: {}", params.prompt.len());
 771
 772        self.run_turn(session_id, cx, |thread, cx| {
 773            let content: Vec<UserMessageContent> = params
 774                .prompt
 775                .into_iter()
 776                .map(Into::into)
 777                .collect::<Vec<_>>();
 778            log::info!("Converted prompt to message: {} chars", content.len());
 779            log::debug!("Message id: {:?}", id);
 780            log::debug!("Message content: {:?}", content);
 781
 782            thread.update(cx, |thread, cx| thread.send(id, content, cx))
 783        })
 784    }
 785
 786    fn resume(
 787        &self,
 788        session_id: &acp::SessionId,
 789        _cx: &mut App,
 790    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
 791        Some(Rc::new(NativeAgentSessionResume {
 792            connection: self.clone(),
 793            session_id: session_id.clone(),
 794        }) as _)
 795    }
 796
 797    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
 798        log::info!("Cancelling on session: {}", session_id);
 799        self.0.update(cx, |agent, cx| {
 800            if let Some(agent) = agent.sessions.get(session_id) {
 801                agent.thread.update(cx, |thread, _cx| thread.cancel());
 802            }
 803        });
 804    }
 805
 806    fn session_editor(
 807        &self,
 808        session_id: &agent_client_protocol::SessionId,
 809        cx: &mut App,
 810    ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
 811        self.0.update(cx, |agent, _cx| {
 812            agent
 813                .sessions
 814                .get(session_id)
 815                .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
 816        })
 817    }
 818
 819    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 820        self
 821    }
 822}
 823
/// `acp_thread::AgentSessionEditor` for a native session.
///
/// Wraps the session's [`Thread`] entity; edits are forwarded directly to it.
struct NativeAgentSessionEditor(Entity<Thread>);
 825
 826impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
 827    fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
 828        Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
 829    }
 830}
 831
/// Resume handle for a native session.
///
/// Built by `NativeAgentConnection::resume` and implements
/// `acp_thread::AgentSessionResume`.
struct NativeAgentSessionResume {
    /// Connection whose `run_turn` drives the resumed turn.
    connection: NativeAgentConnection,
    /// Session whose thread is resumed.
    session_id: acp::SessionId,
}
 836
 837impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
 838    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
 839        self.connection
 840            .run_turn(self.session_id.clone(), cx, |thread, cx| {
 841                thread.update(cx, |thread, cx| thread.resume(cx))
 842            })
 843    }
 844}
 845
// Tests for `NativeAgent` / `NativeAgentConnection`: project-context tracking,
// model listing, and persistence of model selection to the settings file.
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    // The agent's project context tracks worktrees as they are added and picks
    // up a `.rules` file created inside a worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created with no worktrees, so the context starts empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        // Drain pending tasks so the context reflects the new worktree before asserting.
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            // The expected entry id comes from the worktree's view of the new file.
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // `list_models` returns grouped models; under the test registry that is a
    // single "Fake" group containing the `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    // Selecting a model on a session switches that session's thread to the
    // model and rewrites `agent.default_model` in the user's settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a different default model so the
        // assertions below prove it was overwritten.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Presumably the settings-file write triggered by `select_model` runs
        // asynchronously; let it finish before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // Shared setup: installs a test `SettingsStore` global and initializes project
    // settings, agent settings, language support, and a test
    // `LanguageModelRegistry`. The registry presumably provides the "Fake" provider
    // the model tests expect (TODO confirm).
    fn init_test(cx: &mut TestAppContext) {
        // Ignore the error if a logger was already installed by another test.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}