1use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
2use crate::{
3 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
4 EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
5 OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
6 UserMessageContent, WebSearchTool, templates::Templates,
7};
8use crate::{ThreadsDatabase, generate_session_id};
9use acp_thread::{AcpThread, AcpThreadMetadata, AgentHistory, AgentModelSelector};
10use agent_client_protocol as acp;
11use agent_settings::AgentSettings;
12use anyhow::{Context as _, Result, anyhow};
13use collections::{HashSet, IndexMap};
14use fs::Fs;
15use futures::channel::mpsc;
16use futures::{StreamExt, future};
17use gpui::{
18 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
19};
20use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel};
21use project::{Project, ProjectItem, ProjectPath, Worktree};
22use prompt_store::{
23 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
24};
25use settings::update_settings_file;
26use std::any::Any;
27use std::cell::RefCell;
28use std::collections::HashMap;
29use std::path::Path;
30use std::rc::Rc;
31use std::sync::Arc;
32use std::time::Duration;
33use util::ResultExt;
34
/// File names (relative to a worktree root) that are treated as project rules files.
///
/// Order matters: when a worktree contains several of these, the first name in this
/// list that exists as a regular file is the one loaded into the system prompt.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
46
/// Delay before a changed thread is persisted; later changes within this window
/// replace the pending save.
const SAVE_THREAD_DEBOUNCE: Duration = Duration::from_millis(500);
48
49pub struct RulesLoadingError {
50 pub message: SharedString,
51}
52
53/// Holds both the internal Thread and the AcpThread for a session
54struct Session {
55 /// The internal thread that processes messages
56 thread: Entity<Thread>,
57 /// The ACP thread that handles protocol communication
58 acp_thread: WeakEntity<acp_thread::AcpThread>,
59 save_task: Task<()>,
60 _subscriptions: Vec<Subscription>,
61}
62
63pub struct LanguageModels {
64 /// Access language model by ID
65 models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
66 /// Cached list for returning language model information
67 model_list: acp_thread::AgentModelList,
68 refresh_models_rx: watch::Receiver<()>,
69 refresh_models_tx: watch::Sender<()>,
70}
71
72impl LanguageModels {
73 fn new(cx: &App) -> Self {
74 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
75 let mut this = Self {
76 models: HashMap::default(),
77 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
78 refresh_models_rx,
79 refresh_models_tx,
80 };
81 this.refresh_list(cx);
82 this
83 }
84
85 fn refresh_list(&mut self, cx: &App) {
86 let providers = LanguageModelRegistry::global(cx)
87 .read(cx)
88 .providers()
89 .into_iter()
90 .filter(|provider| provider.is_authenticated(cx))
91 .collect::<Vec<_>>();
92
93 let mut language_model_list = IndexMap::default();
94 let mut recommended_models = HashSet::default();
95
96 let mut recommended = Vec::new();
97 for provider in &providers {
98 for model in provider.recommended_models(cx) {
99 recommended_models.insert(model.id());
100 recommended.push(Self::map_language_model_to_info(&model, &provider));
101 }
102 }
103 if !recommended.is_empty() {
104 language_model_list.insert(
105 acp_thread::AgentModelGroupName("Recommended".into()),
106 recommended,
107 );
108 }
109
110 let mut models = HashMap::default();
111 for provider in providers {
112 let mut provider_models = Vec::new();
113 for model in provider.provided_models(cx) {
114 let model_info = Self::map_language_model_to_info(&model, &provider);
115 let model_id = model_info.id.clone();
116 if !recommended_models.contains(&model.id()) {
117 provider_models.push(model_info);
118 }
119 models.insert(model_id, model);
120 }
121 if !provider_models.is_empty() {
122 language_model_list.insert(
123 acp_thread::AgentModelGroupName(provider.name().0.clone()),
124 provider_models,
125 );
126 }
127 }
128
129 self.models = models;
130 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
131 self.refresh_models_tx.send(()).ok();
132 }
133
134 fn watch(&self) -> watch::Receiver<()> {
135 self.refresh_models_rx.clone()
136 }
137
138 pub fn model_from_id(
139 &self,
140 model_id: &acp_thread::AgentModelId,
141 ) -> Option<Arc<dyn LanguageModel>> {
142 self.models.get(model_id).cloned()
143 }
144
145 fn map_language_model_to_info(
146 model: &Arc<dyn LanguageModel>,
147 provider: &Arc<dyn LanguageModelProvider>,
148 ) -> acp_thread::AgentModelInfo {
149 acp_thread::AgentModelInfo {
150 id: Self::model_id(model),
151 name: model.name().0,
152 icon: Some(provider.icon()),
153 }
154 }
155
156 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
157 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
158 }
159}
160
161pub struct NativeAgent {
162 /// Session ID -> Session mapping
163 sessions: HashMap<acp::SessionId, Session>,
164 /// Shared project context for all threads
165 project_context: Rc<RefCell<ProjectContext>>,
166 project_context_needs_refresh: watch::Sender<()>,
167 _maintain_project_context: Task<Result<()>>,
168 context_server_registry: Entity<ContextServerRegistry>,
169 /// Shared templates for all threads
170 templates: Arc<Templates>,
171 /// Cached model information
172 models: LanguageModels,
173 project: Entity<Project>,
174 prompt_store: Option<Entity<PromptStore>>,
175 thread_database: Arc<ThreadsDatabase>,
176 history_watchers: Vec<mpsc::UnboundedSender<AcpThreadMetadata>>,
177 fs: Arc<dyn Fs>,
178 _subscriptions: Vec<Subscription>,
179}
180
181impl NativeAgent {
182 pub async fn new(
183 project: Entity<Project>,
184 templates: Arc<Templates>,
185 prompt_store: Option<Entity<PromptStore>>,
186 fs: Arc<dyn Fs>,
187 cx: &mut AsyncApp,
188 ) -> Result<Entity<NativeAgent>> {
189 log::info!("Creating new NativeAgent");
190
191 let project_context = cx
192 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
193 .await;
194
195 let thread_database = cx
196 .update(|cx| ThreadsDatabase::connect(cx))?
197 .await
198 .map_err(|e| anyhow!(e))?;
199
200 cx.new(|cx| {
201 let mut subscriptions = vec![
202 cx.subscribe(&project, Self::handle_project_event),
203 cx.subscribe(
204 &LanguageModelRegistry::global(cx),
205 Self::handle_models_updated_event,
206 ),
207 ];
208 if let Some(prompt_store) = prompt_store.as_ref() {
209 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
210 }
211
212 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
213 watch::channel(());
214 Self {
215 sessions: HashMap::new(),
216 project_context: Rc::new(RefCell::new(project_context)),
217 project_context_needs_refresh: project_context_needs_refresh_tx,
218 _maintain_project_context: cx.spawn(async move |this, cx| {
219 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
220 }),
221 context_server_registry: cx.new(|cx| {
222 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
223 }),
224 thread_database,
225 templates,
226 models: LanguageModels::new(cx),
227 project,
228 prompt_store,
229 fs,
230 history_watchers: Vec::new(),
231 _subscriptions: subscriptions,
232 }
233 })
234 }
235
236 pub fn insert_session(
237 &mut self,
238 thread: Entity<Thread>,
239 acp_thread: Entity<AcpThread>,
240 cx: &mut Context<Self>,
241 ) {
242 let id = thread.read(cx).id().clone();
243 let weak_thread = acp_thread.downgrade();
244 self.sessions.insert(
245 id,
246 Session {
247 thread: thread.clone(),
248 acp_thread: weak_thread.clone(),
249 save_task: Task::ready(()),
250 _subscriptions: vec![
251 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
252 this.sessions.remove(acp_thread.session_id());
253 }),
254 cx.observe(&thread, move |this, thread, cx| {
255 if let Some(response_stream) =
256 thread.update(cx, |thread, cx| thread.generate_title_if_needed(cx))
257 {
258 NativeAgentConnection::handle_thread_events(
259 response_stream,
260 weak_thread.clone(),
261 cx,
262 )
263 .detach_and_log_err(cx);
264 }
265 this.save_thread(thread.clone(), cx)
266 }),
267 ],
268 },
269 );
270 }
271
272 fn save_thread(&mut self, thread_handle: Entity<Thread>, cx: &mut Context<Self>) {
273 let thread = thread_handle.read(cx);
274 let id = thread.id().clone();
275 let Some(session) = self.sessions.get_mut(&id) else {
276 return;
277 };
278
279 let thread = thread_handle.downgrade();
280 let thread_database = self.thread_database.clone();
281 session.save_task = cx.spawn(async move |this, cx| {
282 cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
283
284 let Some(task) = thread.update(cx, |thread, cx| thread.to_db(cx)).ok() else {
285 return;
286 };
287 let db_thread = task.await;
288 let metadata = thread_database
289 .save_thread(id.clone(), db_thread)
290 .await
291 .log_err();
292 if let Some(metadata) = metadata {
293 this.update(cx, |this, _| {
294 for watcher in this.history_watchers.iter_mut() {
295 watcher
296 .unbounded_send(metadata.clone().to_acp(NATIVE_AGENT_SERVER_NAME))
297 .log_err();
298 }
299 })
300 .ok();
301 }
302 });
303 }
304
305 pub fn models(&self) -> &LanguageModels {
306 &self.models
307 }
308
309 async fn maintain_project_context(
310 this: WeakEntity<Self>,
311 mut needs_refresh: watch::Receiver<()>,
312 cx: &mut AsyncApp,
313 ) -> Result<()> {
314 while needs_refresh.changed().await.is_ok() {
315 let project_context = this
316 .update(cx, |this, cx| {
317 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
318 })?
319 .await;
320 this.update(cx, |this, _| this.project_context.replace(project_context))?;
321 }
322
323 Ok(())
324 }
325
326 fn build_project_context(
327 project: &Entity<Project>,
328 prompt_store: Option<&Entity<PromptStore>>,
329 cx: &mut App,
330 ) -> Task<ProjectContext> {
331 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
332 let worktree_tasks = worktrees
333 .into_iter()
334 .map(|worktree| {
335 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
336 })
337 .collect::<Vec<_>>();
338 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
339 prompt_store.read_with(cx, |prompt_store, cx| {
340 let prompts = prompt_store.default_prompt_metadata();
341 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
342 let contents = prompt_store.load(prompt_metadata.id, cx);
343 async move { (contents.await, prompt_metadata) }
344 });
345 cx.background_spawn(future::join_all(load_tasks))
346 })
347 } else {
348 Task::ready(vec![])
349 };
350
351 cx.spawn(async move |_cx| {
352 let (worktrees, default_user_rules) =
353 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
354
355 let worktrees = worktrees
356 .into_iter()
357 .map(|(worktree, _rules_error)| {
358 // TODO: show error message
359 // if let Some(rules_error) = rules_error {
360 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
361 // }
362 worktree
363 })
364 .collect::<Vec<_>>();
365
366 let default_user_rules = default_user_rules
367 .into_iter()
368 .flat_map(|(contents, prompt_metadata)| match contents {
369 Ok(contents) => Some(UserRulesContext {
370 uuid: match prompt_metadata.id {
371 PromptId::User { uuid } => uuid,
372 PromptId::EditWorkflow => return None,
373 },
374 title: prompt_metadata.title.map(|title| title.to_string()),
375 contents,
376 }),
377 Err(_err) => {
378 // TODO: show error message
379 // this.update(cx, |_, cx| {
380 // cx.emit(RulesLoadingError {
381 // message: format!("{err:?}").into(),
382 // });
383 // })
384 // .ok();
385 None
386 }
387 })
388 .collect::<Vec<_>>();
389
390 ProjectContext::new(worktrees, default_user_rules)
391 })
392 }
393
394 fn load_worktree_info_for_system_prompt(
395 worktree: Entity<Worktree>,
396 project: Entity<Project>,
397 cx: &mut App,
398 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
399 let tree = worktree.read(cx);
400 let root_name = tree.root_name().into();
401 let abs_path = tree.abs_path();
402
403 let mut context = WorktreeContext {
404 root_name,
405 abs_path,
406 rules_file: None,
407 };
408
409 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
410 let Some(rules_task) = rules_task else {
411 return Task::ready((context, None));
412 };
413
414 cx.spawn(async move |_| {
415 let (rules_file, rules_file_error) = match rules_task.await {
416 Ok(rules_file) => (Some(rules_file), None),
417 Err(err) => (
418 None,
419 Some(RulesLoadingError {
420 message: format!("{err}").into(),
421 }),
422 ),
423 };
424 context.rules_file = rules_file;
425 (context, rules_file_error)
426 })
427 }
428
429 fn load_worktree_rules_file(
430 worktree: Entity<Worktree>,
431 project: Entity<Project>,
432 cx: &mut App,
433 ) -> Option<Task<Result<RulesFileContext>>> {
434 let worktree = worktree.read(cx);
435 let worktree_id = worktree.id();
436 let selected_rules_file = RULES_FILE_NAMES
437 .into_iter()
438 .filter_map(|name| {
439 worktree
440 .entry_for_path(name)
441 .filter(|entry| entry.is_file())
442 .map(|entry| entry.path.clone())
443 })
444 .next();
445
446 // Note that Cline supports `.clinerules` being a directory, but that is not currently
447 // supported. This doesn't seem to occur often in GitHub repositories.
448 selected_rules_file.map(|path_in_worktree| {
449 let project_path = ProjectPath {
450 worktree_id,
451 path: path_in_worktree.clone(),
452 };
453 let buffer_task =
454 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
455 let rope_task = cx.spawn(async move |cx| {
456 buffer_task.await?.read_with(cx, |buffer, cx| {
457 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
458 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
459 })?
460 });
461 // Build a string from the rope on a background thread.
462 cx.background_spawn(async move {
463 let (project_entry_id, rope) = rope_task.await?;
464 anyhow::Ok(RulesFileContext {
465 path_in_worktree,
466 text: rope.to_string().trim().to_string(),
467 project_entry_id: project_entry_id.to_usize(),
468 })
469 })
470 })
471 }
472
473 fn handle_project_event(
474 &mut self,
475 _project: Entity<Project>,
476 event: &project::Event,
477 _cx: &mut Context<Self>,
478 ) {
479 match event {
480 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
481 self.project_context_needs_refresh.send(()).ok();
482 }
483 project::Event::WorktreeUpdatedEntries(_, items) => {
484 if items.iter().any(|(path, _, _)| {
485 RULES_FILE_NAMES
486 .iter()
487 .any(|name| path.as_ref() == Path::new(name))
488 }) {
489 self.project_context_needs_refresh.send(()).ok();
490 }
491 }
492 _ => {}
493 }
494 }
495
496 fn handle_prompts_updated_event(
497 &mut self,
498 _prompt_store: Entity<PromptStore>,
499 _event: &prompt_store::PromptsUpdatedEvent,
500 _cx: &mut Context<Self>,
501 ) {
502 self.project_context_needs_refresh.send(()).ok();
503 }
504
505 fn handle_models_updated_event(
506 &mut self,
507 registry: Entity<LanguageModelRegistry>,
508 _event: &language_model::Event,
509 cx: &mut Context<Self>,
510 ) {
511 self.models.refresh_list(cx);
512
513 let default_model = LanguageModelRegistry::read_global(cx)
514 .default_model()
515 .map(|m| m.model.clone());
516
517 for session in self.sessions.values_mut() {
518 session.thread.update(cx, |thread, cx| {
519 if thread.model().is_none()
520 && let Some(model) = default_model.clone()
521 {
522 thread.set_model(model, cx);
523 cx.notify();
524 }
525 let summarization_model = registry
526 .read(cx)
527 .thread_summary_model()
528 .map(|model| model.model.clone());
529 thread.set_summarization_model(summarization_model, cx);
530 });
531 }
532 }
533}
534
535/// Wrapper struct that implements the AgentConnection trait
536#[derive(Clone)]
537pub struct NativeAgentConnection(pub Entity<NativeAgent>);
538
539impl NativeAgentConnection {
540 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
541 self.0
542 .read(cx)
543 .sessions
544 .get(session_id)
545 .map(|session| session.thread.clone())
546 }
547
548 fn run_turn(
549 &self,
550 session_id: acp::SessionId,
551 cx: &mut App,
552 f: impl 'static
553 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
554 ) -> Task<Result<acp::PromptResponse>> {
555 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
556 agent
557 .sessions
558 .get_mut(&session_id)
559 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
560 }) else {
561 return Task::ready(Err(anyhow!("Session not found")));
562 };
563 log::debug!("Found session for: {}", session_id);
564
565 let response_stream = match f(thread, cx) {
566 Ok(stream) => stream,
567 Err(err) => return Task::ready(Err(err)),
568 };
569 Self::handle_thread_events(response_stream, acp_thread, cx)
570 }
571
572 fn handle_thread_events(
573 mut response_stream: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
574 acp_thread: WeakEntity<AcpThread>,
575 cx: &mut App,
576 ) -> Task<Result<acp::PromptResponse>> {
577 cx.spawn(async move |cx| {
578 // Handle response stream and forward to session.acp_thread
579 while let Some(result) = response_stream.next().await {
580 match result {
581 Ok(event) => {
582 log::trace!("Received completion event: {:?}", event);
583
584 match event {
585 ThreadEvent::UserMessage(message) => {
586 acp_thread.update(cx, |thread, cx| {
587 for content in message.content {
588 thread.push_user_content_block(
589 Some(message.id.clone()),
590 content.into(),
591 cx,
592 );
593 }
594 })?;
595 }
596 ThreadEvent::AgentText(text) => {
597 acp_thread.update(cx, |thread, cx| {
598 thread.push_assistant_content_block(
599 acp::ContentBlock::Text(acp::TextContent {
600 text,
601 annotations: None,
602 }),
603 false,
604 cx,
605 )
606 })?;
607 }
608 ThreadEvent::AgentThinking(text) => {
609 acp_thread.update(cx, |thread, cx| {
610 thread.push_assistant_content_block(
611 acp::ContentBlock::Text(acp::TextContent {
612 text,
613 annotations: None,
614 }),
615 true,
616 cx,
617 )
618 })?;
619 }
620 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
621 tool_call,
622 options,
623 response,
624 }) => {
625 let recv = acp_thread.update(cx, |thread, cx| {
626 thread.request_tool_call_authorization(tool_call, options, cx)
627 })?;
628 cx.background_spawn(async move {
629 if let Some(recv) = recv.log_err()
630 && let Some(option) = recv
631 .await
632 .context("authorization sender was dropped")
633 .log_err()
634 {
635 response
636 .send(option)
637 .map(|_| anyhow!("authorization receiver was dropped"))
638 .log_err();
639 }
640 })
641 .detach();
642 }
643 ThreadEvent::ToolCall(tool_call) => {
644 acp_thread.update(cx, |thread, cx| {
645 thread.upsert_tool_call(tool_call, cx)
646 })??;
647 }
648 ThreadEvent::ToolCallUpdate(update) => {
649 acp_thread.update(cx, |thread, cx| {
650 thread.update_tool_call(update, cx)
651 })??;
652 }
653 ThreadEvent::TitleUpdate(title) => {
654 acp_thread
655 .update(cx, |thread, cx| thread.update_title(title, cx))??;
656 }
657 ThreadEvent::Stop(stop_reason) => {
658 log::debug!("Assistant message complete: {:?}", stop_reason);
659 return Ok(acp::PromptResponse { stop_reason });
660 }
661 }
662 }
663 Err(e) => {
664 log::error!("Error in model response stream: {:?}", e);
665 return Err(e);
666 }
667 }
668 }
669
670 log::info!("Response stream completed");
671 anyhow::Ok(acp::PromptResponse {
672 stop_reason: acp::StopReason::EndTurn,
673 })
674 })
675 }
676
677 fn register_tools(
678 thread: &mut Thread,
679 project: Entity<Project>,
680 action_log: Entity<action_log::ActionLog>,
681 cx: &mut Context<Thread>,
682 ) {
683 let language_registry = project.read(cx).languages().clone();
684 thread.add_tool(CopyPathTool::new(project.clone()));
685 thread.add_tool(CreateDirectoryTool::new(project.clone()));
686 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
687 thread.add_tool(DiagnosticsTool::new(project.clone()));
688 thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
689 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
690 thread.add_tool(FindPathTool::new(project.clone()));
691 thread.add_tool(GrepTool::new(project.clone()));
692 thread.add_tool(ListDirectoryTool::new(project.clone()));
693 thread.add_tool(MovePathTool::new(project.clone()));
694 thread.add_tool(NowTool);
695 thread.add_tool(OpenTool::new(project.clone()));
696 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
697 thread.add_tool(TerminalTool::new(project.clone(), cx));
698 thread.add_tool(ThinkingTool);
699 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
700 }
701}
702
703impl AgentModelSelector for NativeAgentConnection {
704 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
705 log::debug!("NativeAgentConnection::list_models called");
706 let list = self.0.read(cx).models.model_list.clone();
707 Task::ready(if list.is_empty() {
708 Err(anyhow::anyhow!("No models available"))
709 } else {
710 Ok(list)
711 })
712 }
713
714 fn select_model(
715 &self,
716 session_id: acp::SessionId,
717 model_id: acp_thread::AgentModelId,
718 cx: &mut App,
719 ) -> Task<Result<()>> {
720 log::info!("Setting model for session {}: {}", session_id, model_id);
721 let Some(thread) = self
722 .0
723 .read(cx)
724 .sessions
725 .get(&session_id)
726 .map(|session| session.thread.clone())
727 else {
728 return Task::ready(Err(anyhow!("Session not found")));
729 };
730
731 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
732 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
733 };
734
735 thread.update(cx, |thread, cx| {
736 thread.set_model(model.clone(), cx);
737 });
738
739 update_settings_file::<AgentSettings>(
740 self.0.read(cx).fs.clone(),
741 cx,
742 move |settings, _cx| {
743 settings.set_model(model);
744 },
745 );
746
747 Task::ready(Ok(()))
748 }
749
750 fn selected_model(
751 &self,
752 session_id: &acp::SessionId,
753 cx: &mut App,
754 ) -> Task<Result<acp_thread::AgentModelInfo>> {
755 let session_id = session_id.clone();
756
757 let Some(thread) = self
758 .0
759 .read(cx)
760 .sessions
761 .get(&session_id)
762 .map(|session| session.thread.clone())
763 else {
764 return Task::ready(Err(anyhow!("Session not found")));
765 };
766 let Some(model) = thread.read(cx).model() else {
767 return Task::ready(Err(anyhow!("Model not found")));
768 };
769 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
770 else {
771 return Task::ready(Err(anyhow!("Provider not found")));
772 };
773 Task::ready(Ok(LanguageModels::map_language_model_to_info(
774 model, &provider,
775 )))
776 }
777
778 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
779 self.0.read(cx).models.watch()
780 }
781}
782
783impl acp_thread::AgentConnection for NativeAgentConnection {
784 fn new_thread(
785 self: Rc<Self>,
786 project: Entity<Project>,
787 cwd: &Path,
788 cx: &mut App,
789 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
790 let agent = self.0.clone();
791 log::info!("Creating new thread for project at: {:?}", cwd);
792
793 cx.spawn(async move |cx| {
794 log::debug!("Starting thread creation in async context");
795
796 // Generate session ID
797 let session_id = generate_session_id();
798 log::info!("Created session with ID: {}", session_id);
799
800 // Create AcpThread
801 let acp_thread = cx.update(|cx| {
802 cx.new(|cx| {
803 acp_thread::AcpThread::new(
804 "agent2",
805 self.clone(),
806 project.clone(),
807 session_id.clone(),
808 cx,
809 )
810 })
811 })?;
812 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
813
814 // Create Thread
815 let thread = agent.update(
816 cx,
817 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
818 // Fetch default model from registry settings
819 let registry = LanguageModelRegistry::read_global(cx);
820
821 // Log available models for debugging
822 let available_count = registry.available_models(cx).count();
823 log::debug!("Total available models: {}", available_count);
824
825 let default_model = registry.default_model().and_then(|default_model| {
826 agent
827 .models
828 .model_from_id(&LanguageModels::model_id(&default_model.model))
829 });
830
831 let summarization_model = registry.thread_summary_model().map(|c| c.model);
832
833 let thread = cx.new(|cx| {
834 let mut thread = Thread::new(
835 session_id.clone(),
836 project.clone(),
837 agent.project_context.clone(),
838 agent.context_server_registry.clone(),
839 action_log.clone(),
840 agent.templates.clone(),
841 default_model,
842 summarization_model,
843 cx,
844 );
845 Self::register_tools(&mut thread, project, action_log, cx);
846 thread
847 });
848
849 Ok(thread)
850 },
851 )??;
852
853 // Store the session
854 agent.update(cx, |agent, cx| {
855 agent.insert_session(thread, acp_thread.clone(), cx)
856 })?;
857
858 Ok(acp_thread)
859 })
860 }
861
862 fn auth_methods(&self) -> &[acp::AuthMethod] {
863 &[] // No auth for in-process
864 }
865
866 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
867 Task::ready(Ok(()))
868 }
869
870 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
871 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
872 }
873
874 fn prompt(
875 &self,
876 id: Option<acp_thread::UserMessageId>,
877 params: acp::PromptRequest,
878 cx: &mut App,
879 ) -> Task<Result<acp::PromptResponse>> {
880 let id = id.expect("UserMessageId is required");
881 let session_id = params.session_id.clone();
882 log::info!("Received prompt request for session: {}", session_id);
883 log::debug!("Prompt blocks count: {}", params.prompt.len());
884
885 self.run_turn(session_id, cx, |thread, cx| {
886 let content: Vec<UserMessageContent> = params
887 .prompt
888 .into_iter()
889 .map(Into::into)
890 .collect::<Vec<_>>();
891 log::info!("Converted prompt to message: {} chars", content.len());
892 log::debug!("Message id: {:?}", id);
893 log::debug!("Message content: {:?}", content);
894
895 thread.update(cx, |thread, cx| thread.send(id, content, cx))
896 })
897 }
898
899 fn resume(
900 &self,
901 session_id: &acp::SessionId,
902 _cx: &mut App,
903 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
904 Some(Rc::new(NativeAgentSessionResume {
905 connection: self.clone(),
906 session_id: session_id.clone(),
907 }) as _)
908 }
909
910 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
911 log::info!("Cancelling on session: {}", session_id);
912 self.0.update(cx, |agent, cx| {
913 if let Some(agent) = agent.sessions.get(session_id) {
914 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
915 }
916 });
917 }
918
919 fn session_editor(
920 &self,
921 session_id: &agent_client_protocol::SessionId,
922 cx: &mut App,
923 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
924 self.0.update(cx, |agent, _cx| {
925 agent
926 .sessions
927 .get(session_id)
928 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
929 })
930 }
931
932 fn history(self: Rc<Self>) -> Option<Rc<dyn AgentHistory>> {
933 Some(self)
934 }
935
936 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
937 self
938 }
939}
940
941struct NativeAgentSessionEditor(Entity<Thread>);
942
943impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
944 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
945 Task::ready(
946 self.0
947 .update(cx, |thread, cx| thread.truncate(message_id, cx)),
948 )
949 }
950}
951
952impl acp_thread::AgentHistory for NativeAgentConnection {
953 fn list_threads(&self, cx: &mut App) -> Task<Result<Vec<AcpThreadMetadata>>> {
954 let database = self.0.read(cx).thread_database.clone();
955 cx.background_executor().spawn(async move {
956 let threads = database.list_threads().await?;
957 anyhow::Ok(
958 threads
959 .into_iter()
960 .map(|thread| thread.to_acp(NATIVE_AGENT_SERVER_NAME))
961 .collect::<Vec<_>>(),
962 )
963 })
964 }
965
966 fn observe_history(&self, cx: &mut App) -> mpsc::UnboundedReceiver<AcpThreadMetadata> {
967 let (tx, rx) = mpsc::unbounded();
968 self.0.update(cx, |this, _| this.history_watchers.push(tx));
969 rx
970 }
971
972 fn load_thread(
973 self: Rc<Self>,
974 project: Entity<Project>,
975 _cwd: &Path,
976 session_id: acp::SessionId,
977 cx: &mut App,
978 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
979 let database = self.0.update(cx, |this, _| this.thread_database.clone());
980 cx.spawn(async move |cx| {
981 let db_thread = database
982 .load_thread(session_id.clone())
983 .await?
984 .context("no such thread found")?;
985
986 let acp_thread = cx.update(|cx| {
987 cx.new(|cx| {
988 acp_thread::AcpThread::new(
989 db_thread.title.clone(),
990 self.clone(),
991 project.clone(),
992 session_id.clone(),
993 cx,
994 )
995 })
996 })?;
997 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
998 let agent = self.0.clone();
999
1000 // Create Thread
1001 let thread = agent.update(cx, |agent, cx| {
1002 let language_model_registry = LanguageModelRegistry::global(cx);
1003 let configured_model = language_model_registry
1004 .update(cx, |registry, cx| {
1005 db_thread
1006 .model
1007 .as_ref()
1008 .and_then(|model| {
1009 let model = SelectedModel {
1010 provider: model.provider.clone().into(),
1011 model: model.model.clone().into(),
1012 };
1013 registry.select_model(&model, cx)
1014 })
1015 .or_else(|| registry.default_model())
1016 })
1017 .context("no default model configured")?;
1018
1019 let model = agent
1020 .models
1021 .model_from_id(&LanguageModels::model_id(&configured_model.model))
1022 .context("no model by id")?;
1023
1024 let summarization_model = language_model_registry
1025 .read(cx)
1026 .thread_summary_model()
1027 .map(|c| c.model);
1028
1029 let thread = cx.new(|cx| {
1030 let mut thread = Thread::from_db(
1031 session_id,
1032 db_thread,
1033 project.clone(),
1034 agent.project_context.clone(),
1035 agent.context_server_registry.clone(),
1036 action_log.clone(),
1037 agent.templates.clone(),
1038 model,
1039 summarization_model,
1040 cx,
1041 );
1042 Self::register_tools(&mut thread, project, action_log, cx);
1043 thread
1044 });
1045
1046 anyhow::Ok(thread)
1047 })??;
1048
1049 // Store the session
1050 agent.update(cx, |agent, cx| {
1051 agent.insert_session(thread.clone(), acp_thread.clone(), cx)
1052 })?;
1053
1054 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
1055 cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))?
1056 .await?;
1057
1058 Ok(acp_thread)
1059 })
1060 }
1061}
1062
1063struct NativeAgentSessionResume {
1064 connection: NativeAgentConnection,
1065 session_id: acp::SessionId,
1066}
1067
1068impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1069 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1070 self.connection
1071 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1072 thread.update(cx, |thread, cx| thread.resume(cx))
1073 })
1074 }
1075}
1076
1077#[cfg(test)]
1078mod tests {
1079 use crate::HistoryStore;
1080
1081 use super::*;
1082 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
1083 use fs::FakeFs;
1084 use gpui::TestAppContext;
1085 use language_model::fake_provider::FakeLanguageModel;
1086 use serde_json::json;
1087 use settings::SettingsStore;
1088 use util::path;
1089
    // Verifies that the agent's `project_context` mirrors the project's
    // worktrees, and that a `.rules` file created inside a worktree is picked
    // up as that worktree's rules file.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created without worktrees, so the context has none.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding `/a` as a worktree shows up in the context once pending work
        // has run; no rules file exists yet.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context: the (empty) file's
        // text and the worktree entry id of `.rules` are recorded.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }
1150
1151 #[gpui::test]
1152 async fn test_listing_models(cx: &mut TestAppContext) {
1153 init_test(cx);
1154 let fs = FakeFs::new(cx.executor());
1155 fs.insert_tree("/", json!({ "a": {} })).await;
1156 let project = Project::test(fs.clone(), [], cx).await;
1157 let connection = NativeAgentConnection(
1158 NativeAgent::new(
1159 project.clone(),
1160 Templates::new(),
1161 None,
1162 fs.clone(),
1163 &mut cx.to_async(),
1164 )
1165 .await
1166 .unwrap(),
1167 );
1168
1169 let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();
1170
1171 let acp_thread::AgentModelList::Grouped(models) = models else {
1172 panic!("Unexpected model group");
1173 };
1174 assert_eq!(
1175 models,
1176 IndexMap::from_iter([(
1177 AgentModelGroupName("Fake".into()),
1178 vec![AgentModelInfo {
1179 id: AgentModelId("fake/fake".into()),
1180 name: "Fake".into(),
1181 icon: Some(ui::IconName::ZedAssistant),
1182 }]
1183 )])
1184 );
1185 }
1186
1187 #[gpui::test]
1188 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1189 init_test(cx);
1190 let fs = FakeFs::new(cx.executor());
1191 fs.create_dir(paths::settings_file().parent().unwrap())
1192 .await
1193 .unwrap();
1194 fs.insert_file(
1195 paths::settings_file(),
1196 json!({
1197 "agent": {
1198 "default_model": {
1199 "provider": "foo",
1200 "model": "bar"
1201 }
1202 }
1203 })
1204 .to_string()
1205 .into_bytes(),
1206 )
1207 .await;
1208 let project = Project::test(fs.clone(), [], cx).await;
1209
1210 // Create the agent and connection
1211 let agent = NativeAgent::new(
1212 project.clone(),
1213 Templates::new(),
1214 None,
1215 fs.clone(),
1216 &mut cx.to_async(),
1217 )
1218 .await
1219 .unwrap();
1220 let connection = NativeAgentConnection(agent.clone());
1221
1222 // Create a thread/session
1223 let acp_thread = cx
1224 .update(|cx| {
1225 Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1226 })
1227 .await
1228 .unwrap();
1229
1230 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1231
1232 // Select a model
1233 let model_id = AgentModelId("fake/fake".into());
1234 cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
1235 .await
1236 .unwrap();
1237
1238 // Verify the thread has the selected model
1239 agent.read_with(cx, |agent, _| {
1240 let session = agent.sessions.get(&session_id).unwrap();
1241 session.thread.read_with(cx, |thread, _| {
1242 assert_eq!(thread.model().unwrap().id().0, "fake");
1243 });
1244 });
1245
1246 cx.run_until_parked();
1247
1248 // Verify settings file was updated
1249 let settings_content = fs.load(paths::settings_file()).await.unwrap();
1250 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
1251
1252 // Check that the agent settings contain the selected model
1253 assert_eq!(
1254 settings_json["agent"]["default_model"]["model"],
1255 json!("fake")
1256 );
1257 assert_eq!(
1258 settings_json["agent"]["default_model"]["provider"],
1259 json!("fake")
1260 );
1261 }
1262
    // Sending a message on a new thread and then completing its summary
    // produces exactly one history entry whose title is the generated summary.
    #[gpui::test]
    async fn test_history(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [], cx).await;

        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));
        let history = connection.clone().history().unwrap();
        let history_store = cx.new(|cx| HistoryStore::get_or_init(cx));

        // Register the native agent's history with the store so its entries
        // can be read back at the end of the test.
        history_store
            .update(cx, |history_store, cx| {
                history_store.load_history(NATIVE_AGENT_SERVER_NAME.clone(), history.as_ref(), cx)
            })
            .await
            .unwrap();

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("")), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let selector = connection.model_selector().unwrap();

        // A separate fake model is installed for summarization so the test can
        // drive the title generation independently of the main model.
        let summarization_model: Arc<dyn LanguageModel> =
            Arc::new(FakeLanguageModel::default()) as _;

        agent.update(cx, |agent, cx| {
            let thread = agent.sessions.get(&session_id).unwrap().thread.clone();
            thread.update(cx, |thread, cx| {
                thread.set_summarization_model(Some(summarization_model.clone()), cx);
            })
        });

        // Resolve the session's selected model back to the fake instance so
        // its completion stream can be fed directly.
        let model = cx
            .update(|cx| selector.selected_model(&session_id, cx))
            .await
            .expect("selected_model should succeed");
        let model = cx
            .update(|cx| agent.read(cx).models().model_from_id(&model.id))
            .unwrap();
        let model = model.as_fake();

        // Spawn the send so the fake model's reply can be streamed while the
        // request is in flight, then wait for the turn to finish.
        let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Hi", cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();
        model.send_last_completion_stream_text_chunk("Hey");
        model.end_last_completion_stream();
        send.await.unwrap();

        // Complete the summary request, then advance the clock by
        // SAVE_THREAD_DEBOUNCE — presumably so the debounced thread save runs
        // before history is inspected.
        summarization_model
            .as_fake()
            .send_last_completion_stream_text_chunk("Saying Hello");
        summarization_model.as_fake().end_last_completion_stream();
        cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE);

        let history = history_store.update(cx, |store, cx| store.entries(cx));
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].title(), "Saying Hello");
    }
1336
1337 fn init_test(cx: &mut TestAppContext) {
1338 env_logger::try_init().ok();
1339 cx.update(|cx| {
1340 let settings_store = SettingsStore::test(cx);
1341 cx.set_global(settings_store);
1342 Project::init_settings(cx);
1343 agent_settings::init(cx);
1344 language::init(cx);
1345 LanguageModelRegistry::test(cx);
1346 });
1347 }
1348}