agent.rs

   1use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
   2use crate::{
   3    AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
   4    DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
   5    MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
   6    ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
   7};
   8use crate::{DbThread, ThreadsDatabase};
   9use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
  10use agent_client_protocol as acp;
  11use agent_settings::AgentSettings;
  12use anyhow::{Context as _, Result, anyhow};
  13use collections::{HashSet, IndexMap};
  14use fs::Fs;
  15use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
  16use futures::future::Shared;
  17use futures::{SinkExt, StreamExt, future};
  18use gpui::{
  19    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  20};
  21use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel};
  22use project::{Project, ProjectItem, ProjectPath, Worktree};
  23use prompt_store::{
  24    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
  25};
  26use settings::update_settings_file;
  27use std::any::Any;
  28use std::cell::RefCell;
  29use std::collections::HashMap;
  30use std::path::Path;
  31use std::rc::Rc;
  32use std::sync::Arc;
  33use util::ResultExt;
  34
  35const RULES_FILE_NAMES: [&'static str; 9] = [
  36    ".rules",
  37    ".cursorrules",
  38    ".windsurfrules",
  39    ".clinerules",
  40    ".github/copilot-instructions.md",
  41    "CLAUDE.md",
  42    "AGENT.md",
  43    "AGENTS.md",
  44    "GEMINI.md",
  45];
  46
  47pub struct RulesLoadingError {
  48    pub message: SharedString,
  49}
  50
  51/// Holds both the internal Thread and the AcpThread for a session
  52struct Session {
  53    /// The internal thread that processes messages
  54    thread: Entity<Thread>,
  55    /// The ACP thread that handles protocol communication
  56    acp_thread: WeakEntity<acp_thread::AcpThread>,
  57    _subscription: Subscription,
  58}
  59
  60pub struct LanguageModels {
  61    /// Access language model by ID
  62    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
  63    /// Cached list for returning language model information
  64    model_list: acp_thread::AgentModelList,
  65    refresh_models_rx: watch::Receiver<()>,
  66    refresh_models_tx: watch::Sender<()>,
  67}
  68
  69impl LanguageModels {
  70    fn new(cx: &App) -> Self {
  71        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  72        let mut this = Self {
  73            models: HashMap::default(),
  74            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  75            refresh_models_rx,
  76            refresh_models_tx,
  77        };
  78        this.refresh_list(cx);
  79        this
  80    }
  81
  82    fn refresh_list(&mut self, cx: &App) {
  83        let providers = LanguageModelRegistry::global(cx)
  84            .read(cx)
  85            .providers()
  86            .into_iter()
  87            .filter(|provider| provider.is_authenticated(cx))
  88            .collect::<Vec<_>>();
  89
  90        let mut language_model_list = IndexMap::default();
  91        let mut recommended_models = HashSet::default();
  92
  93        let mut recommended = Vec::new();
  94        for provider in &providers {
  95            for model in provider.recommended_models(cx) {
  96                recommended_models.insert(model.id());
  97                recommended.push(Self::map_language_model_to_info(&model, &provider));
  98            }
  99        }
 100        if !recommended.is_empty() {
 101            language_model_list.insert(
 102                acp_thread::AgentModelGroupName("Recommended".into()),
 103                recommended,
 104            );
 105        }
 106
 107        let mut models = HashMap::default();
 108        for provider in providers {
 109            let mut provider_models = Vec::new();
 110            for model in provider.provided_models(cx) {
 111                let model_info = Self::map_language_model_to_info(&model, &provider);
 112                let model_id = model_info.id.clone();
 113                if !recommended_models.contains(&model.id()) {
 114                    provider_models.push(model_info);
 115                }
 116                models.insert(model_id, model);
 117            }
 118            if !provider_models.is_empty() {
 119                language_model_list.insert(
 120                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 121                    provider_models,
 122                );
 123            }
 124        }
 125
 126        self.models = models;
 127        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 128        self.refresh_models_tx.send(()).ok();
 129    }
 130
 131    fn watch(&self) -> watch::Receiver<()> {
 132        self.refresh_models_rx.clone()
 133    }
 134
 135    pub fn model_from_id(
 136        &self,
 137        model_id: &acp_thread::AgentModelId,
 138    ) -> Option<Arc<dyn LanguageModel>> {
 139        self.models.get(model_id).cloned()
 140    }
 141
 142    fn map_language_model_to_info(
 143        model: &Arc<dyn LanguageModel>,
 144        provider: &Arc<dyn LanguageModelProvider>,
 145    ) -> acp_thread::AgentModelInfo {
 146        acp_thread::AgentModelInfo {
 147            id: Self::model_id(model),
 148            name: model.name().0,
 149            icon: Some(provider.icon()),
 150        }
 151    }
 152
 153    fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
 154        acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
 155    }
 156}
 157
 158pub struct NativeAgent {
 159    /// Session ID -> Session mapping
 160    sessions: HashMap<acp::SessionId, Session>,
 161    /// Shared project context for all threads
 162    project_context: Rc<RefCell<ProjectContext>>,
 163    project_context_needs_refresh: watch::Sender<()>,
 164    _maintain_project_context: Task<Result<()>>,
 165    context_server_registry: Entity<ContextServerRegistry>,
 166    /// Shared templates for all threads
 167    templates: Arc<Templates>,
 168    /// Cached model information
 169    models: LanguageModels,
 170    project: Entity<Project>,
 171    prompt_store: Option<Entity<PromptStore>>,
 172    thread_database: Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
 173    history_listeners: Vec<UnboundedSender<Vec<AcpThreadMetadata>>>,
 174    fs: Arc<dyn Fs>,
 175    _subscriptions: Vec<Subscription>,
 176}
 177
 178impl NativeAgent {
 179    pub async fn new(
 180        project: Entity<Project>,
 181        templates: Arc<Templates>,
 182        prompt_store: Option<Entity<PromptStore>>,
 183        fs: Arc<dyn Fs>,
 184        cx: &mut AsyncApp,
 185    ) -> Result<Entity<NativeAgent>> {
 186        log::info!("Creating new NativeAgent");
 187
 188        let project_context = cx
 189            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 190            .await;
 191
 192        cx.new(|cx| {
 193            let mut subscriptions = vec![
 194                cx.subscribe(&project, Self::handle_project_event),
 195                cx.subscribe(
 196                    &LanguageModelRegistry::global(cx),
 197                    Self::handle_models_updated_event,
 198                ),
 199            ];
 200            if let Some(prompt_store) = prompt_store.as_ref() {
 201                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 202            }
 203
 204            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 205                watch::channel(());
 206            Self {
 207                sessions: HashMap::new(),
 208                project_context: Rc::new(RefCell::new(project_context)),
 209                project_context_needs_refresh: project_context_needs_refresh_tx,
 210                _maintain_project_context: cx.spawn(async move |this, cx| {
 211                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 212                }),
 213                context_server_registry: cx.new(|cx| {
 214                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 215                }),
 216                thread_database: ThreadsDatabase::connect(cx),
 217                templates,
 218                models: LanguageModels::new(cx),
 219                project,
 220                prompt_store,
 221                fs,
 222                history_listeners: Vec::new(),
 223                _subscriptions: subscriptions,
 224            }
 225        })
 226    }
 227
 228    pub fn models(&self) -> &LanguageModels {
 229        &self.models
 230    }
 231
 232    async fn maintain_project_context(
 233        this: WeakEntity<Self>,
 234        mut needs_refresh: watch::Receiver<()>,
 235        cx: &mut AsyncApp,
 236    ) -> Result<()> {
 237        while needs_refresh.changed().await.is_ok() {
 238            let project_context = this
 239                .update(cx, |this, cx| {
 240                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 241                })?
 242                .await;
 243            this.update(cx, |this, _| this.project_context.replace(project_context))?;
 244        }
 245
 246        Ok(())
 247    }
 248
 249    fn build_project_context(
 250        project: &Entity<Project>,
 251        prompt_store: Option<&Entity<PromptStore>>,
 252        cx: &mut App,
 253    ) -> Task<ProjectContext> {
 254        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 255        let worktree_tasks = worktrees
 256            .into_iter()
 257            .map(|worktree| {
 258                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 259            })
 260            .collect::<Vec<_>>();
 261        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 262            prompt_store.read_with(cx, |prompt_store, cx| {
 263                let prompts = prompt_store.default_prompt_metadata();
 264                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 265                    let contents = prompt_store.load(prompt_metadata.id, cx);
 266                    async move { (contents.await, prompt_metadata) }
 267                });
 268                cx.background_spawn(future::join_all(load_tasks))
 269            })
 270        } else {
 271            Task::ready(vec![])
 272        };
 273
 274        cx.spawn(async move |_cx| {
 275            let (worktrees, default_user_rules) =
 276                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 277
 278            let worktrees = worktrees
 279                .into_iter()
 280                .map(|(worktree, _rules_error)| {
 281                    // TODO: show error message
 282                    // if let Some(rules_error) = rules_error {
 283                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 284                    // }
 285                    worktree
 286                })
 287                .collect::<Vec<_>>();
 288
 289            let default_user_rules = default_user_rules
 290                .into_iter()
 291                .flat_map(|(contents, prompt_metadata)| match contents {
 292                    Ok(contents) => Some(UserRulesContext {
 293                        uuid: match prompt_metadata.id {
 294                            PromptId::User { uuid } => uuid,
 295                            PromptId::EditWorkflow => return None,
 296                        },
 297                        title: prompt_metadata.title.map(|title| title.to_string()),
 298                        contents,
 299                    }),
 300                    Err(_err) => {
 301                        // TODO: show error message
 302                        // this.update(cx, |_, cx| {
 303                        //     cx.emit(RulesLoadingError {
 304                        //         message: format!("{err:?}").into(),
 305                        //     });
 306                        // })
 307                        // .ok();
 308                        None
 309                    }
 310                })
 311                .collect::<Vec<_>>();
 312
 313            ProjectContext::new(worktrees, default_user_rules)
 314        })
 315    }
 316
 317    fn load_worktree_info_for_system_prompt(
 318        worktree: Entity<Worktree>,
 319        project: Entity<Project>,
 320        cx: &mut App,
 321    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 322        let tree = worktree.read(cx);
 323        let root_name = tree.root_name().into();
 324        let abs_path = tree.abs_path();
 325
 326        let mut context = WorktreeContext {
 327            root_name,
 328            abs_path,
 329            rules_file: None,
 330        };
 331
 332        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 333        let Some(rules_task) = rules_task else {
 334            return Task::ready((context, None));
 335        };
 336
 337        cx.spawn(async move |_| {
 338            let (rules_file, rules_file_error) = match rules_task.await {
 339                Ok(rules_file) => (Some(rules_file), None),
 340                Err(err) => (
 341                    None,
 342                    Some(RulesLoadingError {
 343                        message: format!("{err}").into(),
 344                    }),
 345                ),
 346            };
 347            context.rules_file = rules_file;
 348            (context, rules_file_error)
 349        })
 350    }
 351
 352    fn load_worktree_rules_file(
 353        worktree: Entity<Worktree>,
 354        project: Entity<Project>,
 355        cx: &mut App,
 356    ) -> Option<Task<Result<RulesFileContext>>> {
 357        let worktree = worktree.read(cx);
 358        let worktree_id = worktree.id();
 359        let selected_rules_file = RULES_FILE_NAMES
 360            .into_iter()
 361            .filter_map(|name| {
 362                worktree
 363                    .entry_for_path(name)
 364                    .filter(|entry| entry.is_file())
 365                    .map(|entry| entry.path.clone())
 366            })
 367            .next();
 368
 369        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 370        // supported. This doesn't seem to occur often in GitHub repositories.
 371        selected_rules_file.map(|path_in_worktree| {
 372            let project_path = ProjectPath {
 373                worktree_id,
 374                path: path_in_worktree.clone(),
 375            };
 376            let buffer_task =
 377                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 378            let rope_task = cx.spawn(async move |cx| {
 379                buffer_task.await?.read_with(cx, |buffer, cx| {
 380                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 381                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 382                })?
 383            });
 384            // Build a string from the rope on a background thread.
 385            cx.background_spawn(async move {
 386                let (project_entry_id, rope) = rope_task.await?;
 387                anyhow::Ok(RulesFileContext {
 388                    path_in_worktree,
 389                    text: rope.to_string().trim().to_string(),
 390                    project_entry_id: project_entry_id.to_usize(),
 391                })
 392            })
 393        })
 394    }
 395
 396    fn handle_project_event(
 397        &mut self,
 398        _project: Entity<Project>,
 399        event: &project::Event,
 400        _cx: &mut Context<Self>,
 401    ) {
 402        match event {
 403            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 404                self.project_context_needs_refresh.send(()).ok();
 405            }
 406            project::Event::WorktreeUpdatedEntries(_, items) => {
 407                if items.iter().any(|(path, _, _)| {
 408                    RULES_FILE_NAMES
 409                        .iter()
 410                        .any(|name| path.as_ref() == Path::new(name))
 411                }) {
 412                    self.project_context_needs_refresh.send(()).ok();
 413                }
 414            }
 415            _ => {}
 416        }
 417    }
 418
 419    fn handle_prompts_updated_event(
 420        &mut self,
 421        _prompt_store: Entity<PromptStore>,
 422        _event: &prompt_store::PromptsUpdatedEvent,
 423        _cx: &mut Context<Self>,
 424    ) {
 425        self.project_context_needs_refresh.send(()).ok();
 426    }
 427
 428    fn handle_models_updated_event(
 429        &mut self,
 430        _registry: Entity<LanguageModelRegistry>,
 431        _event: &language_model::Event,
 432        cx: &mut Context<Self>,
 433    ) {
 434        self.models.refresh_list(cx);
 435        for session in self.sessions.values_mut() {
 436            session.thread.update(cx, |thread, _| {
 437                let model_id = LanguageModels::model_id(&thread.model());
 438                if let Some(model) = self.models.model_from_id(&model_id) {
 439                    thread.set_model(model.clone());
 440                }
 441            });
 442        }
 443    }
 444}
 445
 446/// Wrapper struct that implements the AgentConnection trait
 447#[derive(Clone)]
 448pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 449
 450impl NativeAgentConnection {
 451    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 452        self.0
 453            .read(cx)
 454            .sessions
 455            .get(session_id)
 456            .map(|session| session.thread.clone())
 457    }
 458
 459    fn run_turn(
 460        &self,
 461        session_id: acp::SessionId,
 462        cx: &mut App,
 463        f: impl 'static
 464        + FnOnce(
 465            Entity<Thread>,
 466            &mut App,
 467        ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
 468    ) -> Task<Result<acp::PromptResponse>> {
 469        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 470            agent
 471                .sessions
 472                .get_mut(&session_id)
 473                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 474        }) else {
 475            return Task::ready(Err(anyhow!("Session not found")));
 476        };
 477        log::debug!("Found session for: {}", session_id);
 478
 479        let mut response_stream = match f(thread, cx) {
 480            Ok(stream) => stream,
 481            Err(err) => return Task::ready(Err(err)),
 482        };
 483        cx.spawn(async move |cx| {
 484            // Handle response stream and forward to session.acp_thread
 485            while let Some(result) = response_stream.next().await {
 486                match result {
 487                    Ok(event) => {
 488                        log::trace!("Received completion event: {:?}", event);
 489
 490                        match event {
 491                            AgentResponseEvent::Text(text) => {
 492                                acp_thread.update(cx, |thread, cx| {
 493                                    thread.push_assistant_content_block(
 494                                        acp::ContentBlock::Text(acp::TextContent {
 495                                            text,
 496                                            annotations: None,
 497                                        }),
 498                                        false,
 499                                        cx,
 500                                    )
 501                                })?;
 502                            }
 503                            AgentResponseEvent::Thinking(text) => {
 504                                acp_thread.update(cx, |thread, cx| {
 505                                    thread.push_assistant_content_block(
 506                                        acp::ContentBlock::Text(acp::TextContent {
 507                                            text,
 508                                            annotations: None,
 509                                        }),
 510                                        true,
 511                                        cx,
 512                                    )
 513                                })?;
 514                            }
 515                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
 516                                tool_call,
 517                                options,
 518                                response,
 519                            }) => {
 520                                let recv = acp_thread.update(cx, |thread, cx| {
 521                                    thread.request_tool_call_authorization(tool_call, options, cx)
 522                                })?;
 523                                cx.background_spawn(async move {
 524                                    if let Some(option) = recv
 525                                        .await
 526                                        .context("authorization sender was dropped")
 527                                        .log_err()
 528                                    {
 529                                        response
 530                                            .send(option)
 531                                            .map(|_| anyhow!("authorization receiver was dropped"))
 532                                            .log_err();
 533                                    }
 534                                })
 535                                .detach();
 536                            }
 537                            AgentResponseEvent::ToolCall(tool_call) => {
 538                                acp_thread.update(cx, |thread, cx| {
 539                                    thread.upsert_tool_call(tool_call, cx)
 540                                })?;
 541                            }
 542                            AgentResponseEvent::ToolCallUpdate(update) => {
 543                                acp_thread.update(cx, |thread, cx| {
 544                                    thread.update_tool_call(update, cx)
 545                                })??;
 546                            }
 547                            AgentResponseEvent::Stop(stop_reason) => {
 548                                log::debug!("Assistant message complete: {:?}", stop_reason);
 549                                return Ok(acp::PromptResponse { stop_reason });
 550                            }
 551                        }
 552                    }
 553                    Err(e) => {
 554                        log::error!("Error in model response stream: {:?}", e);
 555                        return Err(e);
 556                    }
 557                }
 558            }
 559
 560            log::info!("Response stream completed");
 561            anyhow::Ok(acp::PromptResponse {
 562                stop_reason: acp::StopReason::EndTurn,
 563            })
 564        })
 565    }
 566
 567    fn register_tools(
 568        thread: &mut Thread,
 569        project: Entity<Project>,
 570        action_log: Entity<action_log::ActionLog>,
 571        cx: &mut Context<Thread>,
 572    ) {
 573        thread.add_tool(CopyPathTool::new(project.clone()));
 574        thread.add_tool(CreateDirectoryTool::new(project.clone()));
 575        thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
 576        thread.add_tool(DiagnosticsTool::new(project.clone()));
 577        thread.add_tool(EditFileTool::new(cx.entity()));
 578        thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
 579        thread.add_tool(FindPathTool::new(project.clone()));
 580        thread.add_tool(GrepTool::new(project.clone()));
 581        thread.add_tool(ListDirectoryTool::new(project.clone()));
 582        thread.add_tool(MovePathTool::new(project.clone()));
 583        thread.add_tool(NowTool);
 584        thread.add_tool(OpenTool::new(project.clone()));
 585        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
 586        thread.add_tool(TerminalTool::new(project.clone(), cx));
 587        thread.add_tool(ThinkingTool);
 588        thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
 589    }
 590}
 591
 592impl AgentModelSelector for NativeAgentConnection {
 593    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 594        log::debug!("NativeAgentConnection::list_models called");
 595        let list = self.0.read(cx).models.model_list.clone();
 596        Task::ready(if list.is_empty() {
 597            Err(anyhow::anyhow!("No models available"))
 598        } else {
 599            Ok(list)
 600        })
 601    }
 602
 603    fn select_model(
 604        &self,
 605        session_id: acp::SessionId,
 606        model_id: acp_thread::AgentModelId,
 607        cx: &mut App,
 608    ) -> Task<Result<()>> {
 609        log::info!("Setting model for session {}: {}", session_id, model_id);
 610        let Some(thread) = self
 611            .0
 612            .read(cx)
 613            .sessions
 614            .get(&session_id)
 615            .map(|session| session.thread.clone())
 616        else {
 617            return Task::ready(Err(anyhow!("Session not found")));
 618        };
 619
 620        let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
 621            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 622        };
 623
 624        thread.update(cx, |thread, _cx| {
 625            thread.set_model(model.clone());
 626        });
 627
 628        update_settings_file::<AgentSettings>(
 629            self.0.read(cx).fs.clone(),
 630            cx,
 631            move |settings, _cx| {
 632                settings.set_model(model);
 633            },
 634        );
 635
 636        Task::ready(Ok(()))
 637    }
 638
 639    fn selected_model(
 640        &self,
 641        session_id: &acp::SessionId,
 642        cx: &mut App,
 643    ) -> Task<Result<acp_thread::AgentModelInfo>> {
 644        let session_id = session_id.clone();
 645
 646        let Some(thread) = self
 647            .0
 648            .read(cx)
 649            .sessions
 650            .get(&session_id)
 651            .map(|session| session.thread.clone())
 652        else {
 653            return Task::ready(Err(anyhow!("Session not found")));
 654        };
 655        let model = thread.read(cx).model().clone();
 656        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 657        else {
 658            return Task::ready(Err(anyhow!("Provider not found")));
 659        };
 660        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 661            &model, &provider,
 662        )))
 663    }
 664
 665    fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
 666        self.0.read(cx).models.watch()
 667    }
 668}
 669
 670impl acp_thread::AgentConnection for NativeAgentConnection {
 671    fn new_thread(
 672        self: Rc<Self>,
 673        project: Entity<Project>,
 674        cwd: &Path,
 675        cx: &mut App,
 676    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 677        let agent = self.0.clone();
 678        log::info!("Creating new thread for project at: {:?}", cwd);
 679
 680        cx.spawn(async move |cx| {
 681            log::debug!("Starting thread creation in async context");
 682
 683            // Generate session ID
 684            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
 685            log::info!("Created session with ID: {}", session_id);
 686
 687            // Create AcpThread
 688            let acp_thread = cx.update(|cx| {
 689                cx.new(|cx| {
 690                    acp_thread::AcpThread::new(
 691                        "agent2",
 692                        self.clone(),
 693                        project.clone(),
 694                        session_id.clone(),
 695                        cx,
 696                    )
 697                })
 698            })?;
 699            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
 700
 701            // Create Thread
 702            let thread = agent.update(
 703                cx,
 704                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 705                    // Fetch default model from registry settings
 706                    let registry = LanguageModelRegistry::read_global(cx);
 707
 708                    // Log available models for debugging
 709                    let available_count = registry.available_models(cx).count();
 710                    log::debug!("Total available models: {}", available_count);
 711
 712                    let default_model = registry
 713                        .default_model()
 714                        .and_then(|default_model| {
 715                            agent
 716                                .models
 717                                .model_from_id(&LanguageModels::model_id(&default_model.model))
 718                        })
 719                        .ok_or_else(|| {
 720                            log::warn!("No default model configured in settings");
 721                            anyhow!(
 722                                "No default model. Please configure a default model in settings."
 723                            )
 724                        })?;
 725
 726                    let thread = cx.new(|cx| {
 727                        let mut thread = Thread::new(
 728                            project.clone(),
 729                            agent.project_context.clone(),
 730                            agent.context_server_registry.clone(),
 731                            action_log.clone(),
 732                            agent.templates.clone(),
 733                            default_model,
 734                            cx,
 735                        );
 736                        Self::register_tools(&mut thread, project, action_log, cx);
 737                        thread
 738                    });
 739
 740                    Ok(thread)
 741                },
 742            )??;
 743
 744            // Store the session
 745            agent.update(cx, |agent, cx| {
 746                agent.sessions.insert(
 747                    session_id,
 748                    Session {
 749                        thread,
 750                        acp_thread: acp_thread.downgrade(),
 751                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 752                            this.sessions.remove(acp_thread.session_id());
 753                        }),
 754                    },
 755                );
 756            })?;
 757
 758            Ok(acp_thread)
 759        })
 760    }
 761
 762    fn auth_methods(&self) -> &[acp::AuthMethod] {
 763        &[] // No auth for in-process
 764    }
 765
 766    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 767        Task::ready(Ok(()))
 768    }
 769
 770    fn list_threads(&self, cx: &mut App) -> Option<UnboundedReceiver<Vec<AcpThreadMetadata>>> {
 771        let (mut tx, rx) = futures::channel::mpsc::unbounded();
 772        let database = self.0.update(cx, |this, _| {
 773            this.history_listeners.push(tx.clone());
 774            this.thread_database.clone()
 775        });
 776        cx.background_executor()
 777            .spawn(async move {
 778                dbg!("listing!");
 779                let database = database.await.map_err(|e| anyhow!(e))?;
 780                let results = database.list_threads().await?;
 781
 782                dbg!(&results);
 783                tx.send(
 784                    results
 785                        .into_iter()
 786                        .map(|thread| AcpThreadMetadata {
 787                            agent: NATIVE_AGENT_SERVER_NAME.clone(),
 788                            id: thread.id,
 789                            title: thread.title,
 790                            updated_at: thread.updated_at,
 791                        })
 792                        .collect(),
 793                )
 794                .await?;
 795                anyhow::Ok(())
 796            })
 797            .detach_and_log_err(cx);
 798        Some(rx)
 799    }
 800
 801    fn load_thread(
 802        self: Rc<Self>,
 803        project: Entity<Project>,
 804        cwd: &Path,
 805        session_id: acp::SessionId,
 806        cx: &mut App,
 807    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 808        let database = self.0.update(cx, |this, _| this.thread_database.clone());
 809        cx.spawn(async move |cx| {
 810            let database = database.await.map_err(|e| anyhow!(e))?;
 811            let db_thread = database
 812                .load_thread(session_id.clone())
 813                .await?
 814                .context("no such thread found")?;
 815
 816            let acp_thread = cx.update(|cx| {
 817                cx.new(|cx| {
 818                    acp_thread::AcpThread::new(
 819                        db_thread.title,
 820                        self.clone(),
 821                        project.clone(),
 822                        session_id.clone(),
 823                        cx,
 824                    )
 825                })
 826            })?;
 827            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
 828            let agent = self.0.clone();
 829
 830            // Create Thread
 831            let thread = agent.update(
 832                cx,
 833                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 834                    let configured_model = LanguageModelRegistry::global(cx)
 835                        .update(cx, |registry, cx| {
 836                            db_thread
 837                                .model
 838                                .and_then(|model| {
 839                                    let model = SelectedModel {
 840                                        provider: model.provider.clone().into(),
 841                                        model: model.model.clone().into(),
 842                                    };
 843                                    registry.select_model(&model, cx)
 844                                })
 845                                .or_else(|| registry.default_model())
 846                        })
 847                        .context("no default model configured")?;
 848
 849                    let model = agent
 850                        .models
 851                        .model_from_id(&LanguageModels::model_id(&configured_model.model))
 852                        .context("no model by id")?;
 853
 854                    let thread = cx.new(|cx| {
 855                        let mut thread = Thread::new(
 856                            project.clone(),
 857                            agent.project_context.clone(),
 858                            agent.context_server_registry.clone(),
 859                            action_log.clone(),
 860                            agent.templates.clone(),
 861                            model,
 862                            cx,
 863                        );
 864                        Self::register_tools(&mut thread, project, action_log, cx);
 865                        thread
 866                    });
 867
 868                    Ok(thread)
 869                },
 870            )??;
 871
 872            // Store the session
 873            agent.update(cx, |agent, cx| {
 874                agent.sessions.insert(
 875                    session_id,
 876                    Session {
 877                        thread,
 878                        acp_thread: acp_thread.downgrade(),
 879                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 880                            this.sessions.remove(acp_thread.session_id());
 881                        }),
 882                    },
 883                );
 884            })?;
 885
 886            // we need to actually deserialize the DbThread.
 887            todo!()
 888
 889            Ok(acp_thread)
 890        })
 891    }
 892
 893    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 894        Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
 895    }
 896
 897    fn prompt(
 898        &self,
 899        id: Option<acp_thread::UserMessageId>,
 900        params: acp::PromptRequest,
 901        cx: &mut App,
 902    ) -> Task<Result<acp::PromptResponse>> {
 903        let id = id.expect("UserMessageId is required");
 904        let session_id = params.session_id.clone();
 905        log::info!("Received prompt request for session: {}", session_id);
 906        log::debug!("Prompt blocks count: {}", params.prompt.len());
 907
 908        self.run_turn(session_id, cx, |thread, cx| {
 909            let content: Vec<UserMessageContent> = params
 910                .prompt
 911                .into_iter()
 912                .map(Into::into)
 913                .collect::<Vec<_>>();
 914            log::info!("Converted prompt to message: {} chars", content.len());
 915            log::debug!("Message id: {:?}", id);
 916            log::debug!("Message content: {:?}", content);
 917
 918            Ok(thread.update(cx, |thread, cx| {
 919                log::info!(
 920                    "Sending message to thread with model: {:?}",
 921                    thread.model().name()
 922                );
 923                thread.send(id, content, cx)
 924            }))
 925        })
 926    }
 927
 928    fn resume(
 929        &self,
 930        session_id: &acp::SessionId,
 931        _cx: &mut App,
 932    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
 933        Some(Rc::new(NativeAgentSessionResume {
 934            connection: self.clone(),
 935            session_id: session_id.clone(),
 936        }) as _)
 937    }
 938
 939    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
 940        log::info!("Cancelling on session: {}", session_id);
 941        self.0.update(cx, |agent, cx| {
 942            if let Some(agent) = agent.sessions.get(session_id) {
 943                agent.thread.update(cx, |thread, _cx| thread.cancel());
 944            }
 945        });
 946    }
 947
 948    fn session_editor(
 949        &self,
 950        session_id: &agent_client_protocol::SessionId,
 951        cx: &mut App,
 952    ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
 953        self.0.update(cx, |agent, _cx| {
 954            agent
 955                .sessions
 956                .get(session_id)
 957                .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
 958        })
 959    }
 960
 961    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 962        self
 963    }
 964}
 965
 966struct NativeAgentSessionEditor(Entity<Thread>);
 967
 968impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
 969    fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
 970        Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
 971    }
 972}
 973
 974struct NativeAgentSessionResume {
 975    connection: NativeAgentConnection,
 976    session_id: acp::SessionId,
 977}
 978
 979impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
 980    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
 981        self.connection
 982            .run_turn(self.session_id.clone(), cx, |thread, cx| {
 983                thread.update(cx, |thread, cx| thread.resume(cx))
 984            })
 985    }
 986}
 987
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Checks that `NativeAgent::project_context` tracks worktrees added to the project and
    /// picks up a `.rules` file created inside one of them.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project starts without worktrees, so the context lists none.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        // Drive pending tasks so the agent can react to the new worktree.
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Checks that `list_models` returns a grouped list containing only the fake model,
    /// under the "Fake" group. The fake provider is presumably installed by
    /// `LanguageModelRegistry::test` in `init_test`.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Checks that selecting a model for a session does two things:
    /// - it switches that session's thread to the model;
    /// - it overwrites `agent.default_model` in the user settings file, which starts out
    ///   pointing at a different provider/model ("foo"/"bar").
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().id().0, "fake");
            });
        });

        // Let the settings-file update (if it runs as a pending task) complete before
        // reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Installs the global state these tests rely on:
    /// - a logger (the init error is ignored when it is already set);
    /// - a test `SettingsStore`;
    /// - project, agent and language settings;
    /// - a test language-model registry.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}