1use crate::{
2 AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
3 DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
4 MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
5 ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
6};
7use acp_thread::AgentModelSelector;
8use agent_client_protocol as acp;
9use agent_settings::AgentSettings;
10use anyhow::{Context as _, Result, anyhow};
11use collections::{HashSet, IndexMap};
12use fs::Fs;
13use futures::channel::mpsc;
14use futures::{StreamExt, future};
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
17};
18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
19use project::{Project, ProjectItem, ProjectPath, Worktree};
20use prompt_store::{
21 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
22};
23use settings::update_settings_file;
24use std::any::Any;
25use std::cell::RefCell;
26use std::collections::HashMap;
27use std::path::Path;
28use std::rc::Rc;
29use std::sync::Arc;
30use util::ResultExt;
31
/// Worktree-relative file names that may hold project rules for the agent.
///
/// They are checked in this order and the first one that exists as a regular file
/// in a worktree's root is used; the rest are ignored.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
43
44pub struct RulesLoadingError {
45 pub message: SharedString,
46}
47
48/// Holds both the internal Thread and the AcpThread for a session
49struct Session {
50 /// The internal thread that processes messages
51 thread: Entity<Thread>,
52 /// The ACP thread that handles protocol communication
53 acp_thread: WeakEntity<acp_thread::AcpThread>,
54 _subscription: Subscription,
55}
56
57pub struct LanguageModels {
58 /// Access language model by ID
59 models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
60 /// Cached list for returning language model information
61 model_list: acp_thread::AgentModelList,
62 refresh_models_rx: watch::Receiver<()>,
63 refresh_models_tx: watch::Sender<()>,
64}
65
66impl LanguageModels {
67 fn new(cx: &App) -> Self {
68 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
69 let mut this = Self {
70 models: HashMap::default(),
71 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
72 refresh_models_rx,
73 refresh_models_tx,
74 };
75 this.refresh_list(cx);
76 this
77 }
78
79 fn refresh_list(&mut self, cx: &App) {
80 let providers = LanguageModelRegistry::global(cx)
81 .read(cx)
82 .providers()
83 .into_iter()
84 .filter(|provider| provider.is_authenticated(cx))
85 .collect::<Vec<_>>();
86
87 let mut language_model_list = IndexMap::default();
88 let mut recommended_models = HashSet::default();
89
90 let mut recommended = Vec::new();
91 for provider in &providers {
92 for model in provider.recommended_models(cx) {
93 recommended_models.insert(model.id());
94 recommended.push(Self::map_language_model_to_info(&model, &provider));
95 }
96 }
97 if !recommended.is_empty() {
98 language_model_list.insert(
99 acp_thread::AgentModelGroupName("Recommended".into()),
100 recommended,
101 );
102 }
103
104 let mut models = HashMap::default();
105 for provider in providers {
106 let mut provider_models = Vec::new();
107 for model in provider.provided_models(cx) {
108 let model_info = Self::map_language_model_to_info(&model, &provider);
109 let model_id = model_info.id.clone();
110 if !recommended_models.contains(&model.id()) {
111 provider_models.push(model_info);
112 }
113 models.insert(model_id, model);
114 }
115 if !provider_models.is_empty() {
116 language_model_list.insert(
117 acp_thread::AgentModelGroupName(provider.name().0.clone()),
118 provider_models,
119 );
120 }
121 }
122
123 self.models = models;
124 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
125 self.refresh_models_tx.send(()).ok();
126 }
127
128 fn watch(&self) -> watch::Receiver<()> {
129 self.refresh_models_rx.clone()
130 }
131
132 pub fn model_from_id(
133 &self,
134 model_id: &acp_thread::AgentModelId,
135 ) -> Option<Arc<dyn LanguageModel>> {
136 self.models.get(model_id).cloned()
137 }
138
139 fn map_language_model_to_info(
140 model: &Arc<dyn LanguageModel>,
141 provider: &Arc<dyn LanguageModelProvider>,
142 ) -> acp_thread::AgentModelInfo {
143 acp_thread::AgentModelInfo {
144 id: Self::model_id(model),
145 name: model.name().0,
146 icon: Some(provider.icon()),
147 }
148 }
149
150 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
151 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
152 }
153}
154
155pub struct NativeAgent {
156 /// Session ID -> Session mapping
157 sessions: HashMap<acp::SessionId, Session>,
158 /// Shared project context for all threads
159 project_context: Rc<RefCell<ProjectContext>>,
160 project_context_needs_refresh: watch::Sender<()>,
161 _maintain_project_context: Task<Result<()>>,
162 context_server_registry: Entity<ContextServerRegistry>,
163 /// Shared templates for all threads
164 templates: Arc<Templates>,
165 /// Cached model information
166 models: LanguageModels,
167 project: Entity<Project>,
168 prompt_store: Option<Entity<PromptStore>>,
169 fs: Arc<dyn Fs>,
170 _subscriptions: Vec<Subscription>,
171}
172
173impl NativeAgent {
174 pub async fn new(
175 project: Entity<Project>,
176 templates: Arc<Templates>,
177 prompt_store: Option<Entity<PromptStore>>,
178 fs: Arc<dyn Fs>,
179 cx: &mut AsyncApp,
180 ) -> Result<Entity<NativeAgent>> {
181 log::info!("Creating new NativeAgent");
182
183 let project_context = cx
184 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
185 .await;
186
187 cx.new(|cx| {
188 let mut subscriptions = vec![
189 cx.subscribe(&project, Self::handle_project_event),
190 cx.subscribe(
191 &LanguageModelRegistry::global(cx),
192 Self::handle_models_updated_event,
193 ),
194 ];
195 if let Some(prompt_store) = prompt_store.as_ref() {
196 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
197 }
198
199 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
200 watch::channel(());
201 Self {
202 sessions: HashMap::new(),
203 project_context: Rc::new(RefCell::new(project_context)),
204 project_context_needs_refresh: project_context_needs_refresh_tx,
205 _maintain_project_context: cx.spawn(async move |this, cx| {
206 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
207 }),
208 context_server_registry: cx.new(|cx| {
209 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
210 }),
211 templates,
212 models: LanguageModels::new(cx),
213 project,
214 prompt_store,
215 fs,
216 _subscriptions: subscriptions,
217 }
218 })
219 }
220
221 pub fn models(&self) -> &LanguageModels {
222 &self.models
223 }
224
225 async fn maintain_project_context(
226 this: WeakEntity<Self>,
227 mut needs_refresh: watch::Receiver<()>,
228 cx: &mut AsyncApp,
229 ) -> Result<()> {
230 while needs_refresh.changed().await.is_ok() {
231 let project_context = this
232 .update(cx, |this, cx| {
233 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
234 })?
235 .await;
236 this.update(cx, |this, _| this.project_context.replace(project_context))?;
237 }
238
239 Ok(())
240 }
241
242 fn build_project_context(
243 project: &Entity<Project>,
244 prompt_store: Option<&Entity<PromptStore>>,
245 cx: &mut App,
246 ) -> Task<ProjectContext> {
247 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
248 let worktree_tasks = worktrees
249 .into_iter()
250 .map(|worktree| {
251 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
252 })
253 .collect::<Vec<_>>();
254 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
255 prompt_store.read_with(cx, |prompt_store, cx| {
256 let prompts = prompt_store.default_prompt_metadata();
257 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
258 let contents = prompt_store.load(prompt_metadata.id, cx);
259 async move { (contents.await, prompt_metadata) }
260 });
261 cx.background_spawn(future::join_all(load_tasks))
262 })
263 } else {
264 Task::ready(vec![])
265 };
266
267 cx.spawn(async move |_cx| {
268 let (worktrees, default_user_rules) =
269 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
270
271 let worktrees = worktrees
272 .into_iter()
273 .map(|(worktree, _rules_error)| {
274 // TODO: show error message
275 // if let Some(rules_error) = rules_error {
276 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
277 // }
278 worktree
279 })
280 .collect::<Vec<_>>();
281
282 let default_user_rules = default_user_rules
283 .into_iter()
284 .flat_map(|(contents, prompt_metadata)| match contents {
285 Ok(contents) => Some(UserRulesContext {
286 uuid: match prompt_metadata.id {
287 PromptId::User { uuid } => uuid,
288 PromptId::EditWorkflow => return None,
289 },
290 title: prompt_metadata.title.map(|title| title.to_string()),
291 contents,
292 }),
293 Err(_err) => {
294 // TODO: show error message
295 // this.update(cx, |_, cx| {
296 // cx.emit(RulesLoadingError {
297 // message: format!("{err:?}").into(),
298 // });
299 // })
300 // .ok();
301 None
302 }
303 })
304 .collect::<Vec<_>>();
305
306 ProjectContext::new(worktrees, default_user_rules)
307 })
308 }
309
310 fn load_worktree_info_for_system_prompt(
311 worktree: Entity<Worktree>,
312 project: Entity<Project>,
313 cx: &mut App,
314 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
315 let tree = worktree.read(cx);
316 let root_name = tree.root_name().into();
317 let abs_path = tree.abs_path();
318
319 let mut context = WorktreeContext {
320 root_name,
321 abs_path,
322 rules_file: None,
323 };
324
325 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
326 let Some(rules_task) = rules_task else {
327 return Task::ready((context, None));
328 };
329
330 cx.spawn(async move |_| {
331 let (rules_file, rules_file_error) = match rules_task.await {
332 Ok(rules_file) => (Some(rules_file), None),
333 Err(err) => (
334 None,
335 Some(RulesLoadingError {
336 message: format!("{err}").into(),
337 }),
338 ),
339 };
340 context.rules_file = rules_file;
341 (context, rules_file_error)
342 })
343 }
344
345 fn load_worktree_rules_file(
346 worktree: Entity<Worktree>,
347 project: Entity<Project>,
348 cx: &mut App,
349 ) -> Option<Task<Result<RulesFileContext>>> {
350 let worktree = worktree.read(cx);
351 let worktree_id = worktree.id();
352 let selected_rules_file = RULES_FILE_NAMES
353 .into_iter()
354 .filter_map(|name| {
355 worktree
356 .entry_for_path(name)
357 .filter(|entry| entry.is_file())
358 .map(|entry| entry.path.clone())
359 })
360 .next();
361
362 // Note that Cline supports `.clinerules` being a directory, but that is not currently
363 // supported. This doesn't seem to occur often in GitHub repositories.
364 selected_rules_file.map(|path_in_worktree| {
365 let project_path = ProjectPath {
366 worktree_id,
367 path: path_in_worktree.clone(),
368 };
369 let buffer_task =
370 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
371 let rope_task = cx.spawn(async move |cx| {
372 buffer_task.await?.read_with(cx, |buffer, cx| {
373 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
374 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
375 })?
376 });
377 // Build a string from the rope on a background thread.
378 cx.background_spawn(async move {
379 let (project_entry_id, rope) = rope_task.await?;
380 anyhow::Ok(RulesFileContext {
381 path_in_worktree,
382 text: rope.to_string().trim().to_string(),
383 project_entry_id: project_entry_id.to_usize(),
384 })
385 })
386 })
387 }
388
389 fn handle_project_event(
390 &mut self,
391 _project: Entity<Project>,
392 event: &project::Event,
393 _cx: &mut Context<Self>,
394 ) {
395 match event {
396 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
397 self.project_context_needs_refresh.send(()).ok();
398 }
399 project::Event::WorktreeUpdatedEntries(_, items) => {
400 if items.iter().any(|(path, _, _)| {
401 RULES_FILE_NAMES
402 .iter()
403 .any(|name| path.as_ref() == Path::new(name))
404 }) {
405 self.project_context_needs_refresh.send(()).ok();
406 }
407 }
408 _ => {}
409 }
410 }
411
412 fn handle_prompts_updated_event(
413 &mut self,
414 _prompt_store: Entity<PromptStore>,
415 _event: &prompt_store::PromptsUpdatedEvent,
416 _cx: &mut Context<Self>,
417 ) {
418 self.project_context_needs_refresh.send(()).ok();
419 }
420
421 fn handle_models_updated_event(
422 &mut self,
423 _registry: Entity<LanguageModelRegistry>,
424 _event: &language_model::Event,
425 cx: &mut Context<Self>,
426 ) {
427 self.models.refresh_list(cx);
428 for session in self.sessions.values_mut() {
429 session.thread.update(cx, |thread, _| {
430 let model_id = LanguageModels::model_id(&thread.model());
431 if let Some(model) = self.models.model_from_id(&model_id) {
432 thread.set_model(model.clone());
433 }
434 });
435 }
436 }
437}
438
439/// Wrapper struct that implements the AgentConnection trait
440#[derive(Clone)]
441pub struct NativeAgentConnection(pub Entity<NativeAgent>);
442
443impl NativeAgentConnection {
444 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
445 self.0
446 .read(cx)
447 .sessions
448 .get(session_id)
449 .map(|session| session.thread.clone())
450 }
451
452 fn run_turn(
453 &self,
454 session_id: acp::SessionId,
455 cx: &mut App,
456 f: impl 'static
457 + FnOnce(
458 Entity<Thread>,
459 &mut App,
460 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
461 ) -> Task<Result<acp::PromptResponse>> {
462 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
463 agent
464 .sessions
465 .get_mut(&session_id)
466 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
467 }) else {
468 return Task::ready(Err(anyhow!("Session not found")));
469 };
470 log::debug!("Found session for: {}", session_id);
471
472 let mut response_stream = match f(thread, cx) {
473 Ok(stream) => stream,
474 Err(err) => return Task::ready(Err(err)),
475 };
476 cx.spawn(async move |cx| {
477 // Handle response stream and forward to session.acp_thread
478 while let Some(result) = response_stream.next().await {
479 match result {
480 Ok(event) => {
481 log::trace!("Received completion event: {:?}", event);
482
483 match event {
484 AgentResponseEvent::Text(text) => {
485 acp_thread.update(cx, |thread, cx| {
486 thread.push_assistant_content_block(
487 acp::ContentBlock::Text(acp::TextContent {
488 text,
489 annotations: None,
490 }),
491 false,
492 cx,
493 )
494 })?;
495 }
496 AgentResponseEvent::Thinking(text) => {
497 acp_thread.update(cx, |thread, cx| {
498 thread.push_assistant_content_block(
499 acp::ContentBlock::Text(acp::TextContent {
500 text,
501 annotations: None,
502 }),
503 true,
504 cx,
505 )
506 })?;
507 }
508 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
509 tool_call,
510 options,
511 response,
512 }) => {
513 let recv = acp_thread.update(cx, |thread, cx| {
514 thread.request_tool_call_authorization(tool_call, options, cx)
515 })?;
516 cx.background_spawn(async move {
517 if let Some(option) = recv
518 .await
519 .context("authorization sender was dropped")
520 .log_err()
521 {
522 response
523 .send(option)
524 .map(|_| anyhow!("authorization receiver was dropped"))
525 .log_err();
526 }
527 })
528 .detach();
529 }
530 AgentResponseEvent::ToolCall(tool_call) => {
531 acp_thread.update(cx, |thread, cx| {
532 thread.upsert_tool_call(tool_call, cx)
533 })?;
534 }
535 AgentResponseEvent::ToolCallUpdate(update) => {
536 acp_thread.update(cx, |thread, cx| {
537 thread.update_tool_call(update, cx)
538 })??;
539 }
540 AgentResponseEvent::Stop(stop_reason) => {
541 log::debug!("Assistant message complete: {:?}", stop_reason);
542 return Ok(acp::PromptResponse { stop_reason });
543 }
544 }
545 }
546 Err(e) => {
547 log::error!("Error in model response stream: {:?}", e);
548 return Err(e);
549 }
550 }
551 }
552
553 log::info!("Response stream completed");
554 anyhow::Ok(acp::PromptResponse {
555 stop_reason: acp::StopReason::EndTurn,
556 })
557 })
558 }
559}
560
561impl AgentModelSelector for NativeAgentConnection {
562 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
563 log::debug!("NativeAgentConnection::list_models called");
564 let list = self.0.read(cx).models.model_list.clone();
565 Task::ready(if list.is_empty() {
566 Err(anyhow::anyhow!("No models available"))
567 } else {
568 Ok(list)
569 })
570 }
571
572 fn select_model(
573 &self,
574 session_id: acp::SessionId,
575 model_id: acp_thread::AgentModelId,
576 cx: &mut App,
577 ) -> Task<Result<()>> {
578 log::info!("Setting model for session {}: {}", session_id, model_id);
579 let Some(thread) = self
580 .0
581 .read(cx)
582 .sessions
583 .get(&session_id)
584 .map(|session| session.thread.clone())
585 else {
586 return Task::ready(Err(anyhow!("Session not found")));
587 };
588
589 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
590 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
591 };
592
593 thread.update(cx, |thread, _cx| {
594 thread.set_model(model.clone());
595 });
596
597 update_settings_file::<AgentSettings>(
598 self.0.read(cx).fs.clone(),
599 cx,
600 move |settings, _cx| {
601 settings.set_model(model);
602 },
603 );
604
605 Task::ready(Ok(()))
606 }
607
608 fn selected_model(
609 &self,
610 session_id: &acp::SessionId,
611 cx: &mut App,
612 ) -> Task<Result<acp_thread::AgentModelInfo>> {
613 let session_id = session_id.clone();
614
615 let Some(thread) = self
616 .0
617 .read(cx)
618 .sessions
619 .get(&session_id)
620 .map(|session| session.thread.clone())
621 else {
622 return Task::ready(Err(anyhow!("Session not found")));
623 };
624 let model = thread.read(cx).model().clone();
625 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
626 else {
627 return Task::ready(Err(anyhow!("Provider not found")));
628 };
629 Task::ready(Ok(LanguageModels::map_language_model_to_info(
630 &model, &provider,
631 )))
632 }
633
634 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
635 self.0.read(cx).models.watch()
636 }
637}
638
639impl acp_thread::AgentConnection for NativeAgentConnection {
640 fn new_thread(
641 self: Rc<Self>,
642 project: Entity<Project>,
643 cwd: &Path,
644 cx: &mut App,
645 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
646 let agent = self.0.clone();
647 log::info!("Creating new thread for project at: {:?}", cwd);
648
649 cx.spawn(async move |cx| {
650 log::debug!("Starting thread creation in async context");
651
652 // Generate session ID
653 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
654 log::info!("Created session with ID: {}", session_id);
655
656 // Create AcpThread
657 let acp_thread = cx.update(|cx| {
658 cx.new(|cx| {
659 acp_thread::AcpThread::new(
660 "agent2",
661 self.clone(),
662 project.clone(),
663 session_id.clone(),
664 cx,
665 )
666 })
667 })?;
668 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
669
670 // Create Thread
671 let thread = agent.update(
672 cx,
673 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
674 // Fetch default model from registry settings
675 let registry = LanguageModelRegistry::read_global(cx);
676
677 // Log available models for debugging
678 let available_count = registry.available_models(cx).count();
679 log::debug!("Total available models: {}", available_count);
680
681 let default_model = registry
682 .default_model()
683 .and_then(|default_model| {
684 agent
685 .models
686 .model_from_id(&LanguageModels::model_id(&default_model.model))
687 })
688 .ok_or_else(|| {
689 log::warn!("No default model configured in settings");
690 anyhow!(
691 "No default model. Please configure a default model in settings."
692 )
693 })?;
694
695 let thread = cx.new(|cx| {
696 let mut thread = Thread::new(
697 project.clone(),
698 agent.project_context.clone(),
699 agent.context_server_registry.clone(),
700 action_log.clone(),
701 agent.templates.clone(),
702 default_model,
703 cx,
704 );
705 thread.add_tool(CopyPathTool::new(project.clone()));
706 thread.add_tool(CreateDirectoryTool::new(project.clone()));
707 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
708 thread.add_tool(DiagnosticsTool::new(project.clone()));
709 thread.add_tool(EditFileTool::new(cx.entity()));
710 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
711 thread.add_tool(FindPathTool::new(project.clone()));
712 thread.add_tool(GrepTool::new(project.clone()));
713 thread.add_tool(ListDirectoryTool::new(project.clone()));
714 thread.add_tool(MovePathTool::new(project.clone()));
715 thread.add_tool(NowTool);
716 thread.add_tool(OpenTool::new(project.clone()));
717 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
718 thread.add_tool(TerminalTool::new(project.clone(), cx));
719 thread.add_tool(ThinkingTool);
720 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
721 thread
722 });
723
724 Ok(thread)
725 },
726 )??;
727
728 // Store the session
729 agent.update(cx, |agent, cx| {
730 agent.sessions.insert(
731 session_id,
732 Session {
733 thread,
734 acp_thread: acp_thread.downgrade(),
735 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
736 this.sessions.remove(acp_thread.session_id());
737 }),
738 },
739 );
740 })?;
741
742 Ok(acp_thread)
743 })
744 }
745
746 fn auth_methods(&self) -> &[acp::AuthMethod] {
747 &[] // No auth for in-process
748 }
749
750 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
751 Task::ready(Ok(()))
752 }
753
754 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
755 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
756 }
757
758 fn prompt(
759 &self,
760 id: Option<acp_thread::UserMessageId>,
761 params: acp::PromptRequest,
762 cx: &mut App,
763 ) -> Task<Result<acp::PromptResponse>> {
764 let id = id.expect("UserMessageId is required");
765 let session_id = params.session_id.clone();
766 log::info!("Received prompt request for session: {}", session_id);
767 log::debug!("Prompt blocks count: {}", params.prompt.len());
768
769 self.run_turn(session_id, cx, |thread, cx| {
770 let content: Vec<UserMessageContent> = params
771 .prompt
772 .into_iter()
773 .map(Into::into)
774 .collect::<Vec<_>>();
775 log::info!("Converted prompt to message: {} chars", content.len());
776 log::debug!("Message id: {:?}", id);
777 log::debug!("Message content: {:?}", content);
778
779 Ok(thread.update(cx, |thread, cx| {
780 log::info!(
781 "Sending message to thread with model: {:?}",
782 thread.model().name()
783 );
784 thread.send(id, content, cx)
785 }))
786 })
787 }
788
789 fn resume(
790 &self,
791 session_id: &acp::SessionId,
792 _cx: &mut App,
793 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
794 Some(Rc::new(NativeAgentSessionResume {
795 connection: self.clone(),
796 session_id: session_id.clone(),
797 }) as _)
798 }
799
800 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
801 log::info!("Cancelling on session: {}", session_id);
802 self.0.update(cx, |agent, cx| {
803 if let Some(agent) = agent.sessions.get(session_id) {
804 agent.thread.update(cx, |thread, _cx| thread.cancel());
805 }
806 });
807 }
808
809 fn session_editor(
810 &self,
811 session_id: &agent_client_protocol::SessionId,
812 cx: &mut App,
813 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
814 self.0.update(cx, |agent, _cx| {
815 agent
816 .sessions
817 .get(session_id)
818 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
819 })
820 }
821
822 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
823 self
824 }
825}
826
827struct NativeAgentSessionEditor(Entity<Thread>);
828
829impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
830 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
831 Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
832 }
833}
834
835struct NativeAgentSessionResume {
836 connection: NativeAgentConnection,
837 session_id: acp::SessionId,
838}
839
840impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
841 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
842 self.connection
843 .run_turn(self.session_id.clone(), cx, |thread, cx| {
844 thread.update(cx, |thread, cx| thread.resume(cx))
845 })
846 }
847}
848
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Checks that the agent's shared project context follows worktree additions
    /// and the creation of a `.rules` file in a worktree root.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created without worktrees, so the context lists none.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a visible worktree should trigger a context rebuild that includes it.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Checks that `list_models` groups the test registry's models by provider.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        // Expects exactly one "Fake" provider group holding the `fake/fake` model;
        // presumably the only one `LanguageModelRegistry::test` registers.
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Checks that selecting a model updates the session's thread and rewrites
    /// `agent.default_model` in the settings file (which starts out as `foo`/`bar`).
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().id().0, "fake");
            });
        });

        // Let the settings-file write scheduled by `select_model` complete.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Installs a test settings store and initializes the project, agent and
    /// language settings plus a test language model registry.
    fn init_test(cx: &mut TestAppContext) {
        // Ignore the error when a logger was already installed by another test.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}