1use crate::{
2 AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
3 DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
4 MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
5 ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
6};
7use acp_thread::AgentModelSelector;
8use agent_client_protocol as acp;
9use agent_settings::AgentSettings;
10use anyhow::{Context as _, Result, anyhow};
11use collections::{HashSet, IndexMap};
12use fs::Fs;
13use futures::channel::mpsc;
14use futures::{StreamExt, future};
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
17};
18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
19use project::{Project, ProjectItem, ProjectPath, Worktree};
20use prompt_store::{
21 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
22};
23use settings::update_settings_file;
24use std::any::Any;
25use std::cell::RefCell;
26use std::collections::HashMap;
27use std::path::Path;
28use std::rc::Rc;
29use std::sync::Arc;
30use util::ResultExt;
31
/// Candidate rules-file names, relative to a worktree root, in priority order.
///
/// For each worktree only the first entry that exists *as a file* is loaded into the
/// project context; later names are ignored once one matches. A change to any path
/// equal to one of these names triggers a project-context refresh.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
43
/// Error produced when a worktree's rules file was found but could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure, formatted from the underlying error.
    pub message: SharedString,
}
47
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held weakly: the session does not keep the `AcpThread` alive.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Observes the release of the `AcpThread` so the session can be removed from the
    /// agent's session map at that point.
    _subscription: Subscription,
}
56
/// Cache of the language models offered by authenticated providers, rebuilt from the
/// global `LanguageModelRegistry` on demand.
pub struct LanguageModels {
    /// Access language model by ID (the `provider_id/model_id` form built by `model_id`)
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information, grouped for display
    model_list: acp_thread::AgentModelList,
    /// Template receiver cloned out by `watch()` so callers can observe list refreshes
    refresh_models_rx: watch::Receiver<()>,
    /// Signaled at the end of every list refresh
    refresh_models_tx: watch::Sender<()>,
}
65
66impl LanguageModels {
67 fn new(cx: &App) -> Self {
68 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
69 let mut this = Self {
70 models: HashMap::default(),
71 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
72 refresh_models_rx,
73 refresh_models_tx,
74 };
75 this.refresh_list(cx);
76 this
77 }
78
79 fn refresh_list(&mut self, cx: &App) {
80 let providers = LanguageModelRegistry::global(cx)
81 .read(cx)
82 .providers()
83 .into_iter()
84 .filter(|provider| provider.is_authenticated(cx))
85 .collect::<Vec<_>>();
86
87 let mut language_model_list = IndexMap::default();
88 let mut recommended_models = HashSet::default();
89
90 let mut recommended = Vec::new();
91 for provider in &providers {
92 for model in provider.recommended_models(cx) {
93 recommended_models.insert(model.id());
94 recommended.push(Self::map_language_model_to_info(&model, &provider));
95 }
96 }
97 if !recommended.is_empty() {
98 language_model_list.insert(
99 acp_thread::AgentModelGroupName("Recommended".into()),
100 recommended,
101 );
102 }
103
104 let mut models = HashMap::default();
105 for provider in providers {
106 let mut provider_models = Vec::new();
107 for model in provider.provided_models(cx) {
108 let model_info = Self::map_language_model_to_info(&model, &provider);
109 let model_id = model_info.id.clone();
110 if !recommended_models.contains(&model.id()) {
111 provider_models.push(model_info);
112 }
113 models.insert(model_id, model);
114 }
115 if !provider_models.is_empty() {
116 language_model_list.insert(
117 acp_thread::AgentModelGroupName(provider.name().0.clone()),
118 provider_models,
119 );
120 }
121 }
122
123 self.models = models;
124 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
125 self.refresh_models_tx.send(()).ok();
126 }
127
128 fn watch(&self) -> watch::Receiver<()> {
129 self.refresh_models_rx.clone()
130 }
131
132 pub fn model_from_id(
133 &self,
134 model_id: &acp_thread::AgentModelId,
135 ) -> Option<Arc<dyn LanguageModel>> {
136 self.models.get(model_id).cloned()
137 }
138
139 fn map_language_model_to_info(
140 model: &Arc<dyn LanguageModel>,
141 provider: &Arc<dyn LanguageModelProvider>,
142 ) -> acp_thread::AgentModelInfo {
143 acp_thread::AgentModelInfo {
144 id: Self::model_id(model),
145 name: model.name().0,
146 icon: Some(provider.icon()),
147 }
148 }
149
150 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
151 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
152 }
153}
154
/// In-process agent that owns the sessions of a project and the state they share:
/// project context, templates, context servers, and the available language models.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Rc<RefCell<ProjectContext>>,
    /// Sending on this channel asks `maintain_project_context` to rebuild
    /// `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`; kept so the task lives as long as
    /// the agent.
    _maintain_project_context: Task<Result<()>>,
    /// Context servers, built from the project's context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project the agent operates on.
    project: Entity<Project>,
    /// Optional prompt store whose default prompts become user rules in the project
    /// context.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system used to persist settings changes such as the selected model.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, language model registry and (optionally) prompt
    /// store events.
    _subscriptions: Vec<Subscription>,
}
172
173impl NativeAgent {
174 pub async fn new(
175 project: Entity<Project>,
176 templates: Arc<Templates>,
177 prompt_store: Option<Entity<PromptStore>>,
178 fs: Arc<dyn Fs>,
179 cx: &mut AsyncApp,
180 ) -> Result<Entity<NativeAgent>> {
181 log::info!("Creating new NativeAgent");
182
183 let project_context = cx
184 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
185 .await;
186
187 cx.new(|cx| {
188 let mut subscriptions = vec![
189 cx.subscribe(&project, Self::handle_project_event),
190 cx.subscribe(
191 &LanguageModelRegistry::global(cx),
192 Self::handle_models_updated_event,
193 ),
194 ];
195 if let Some(prompt_store) = prompt_store.as_ref() {
196 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
197 }
198
199 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
200 watch::channel(());
201 Self {
202 sessions: HashMap::new(),
203 project_context: Rc::new(RefCell::new(project_context)),
204 project_context_needs_refresh: project_context_needs_refresh_tx,
205 _maintain_project_context: cx.spawn(async move |this, cx| {
206 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
207 }),
208 context_server_registry: cx.new(|cx| {
209 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
210 }),
211 templates,
212 models: LanguageModels::new(cx),
213 project,
214 prompt_store,
215 fs,
216 _subscriptions: subscriptions,
217 }
218 })
219 }
220
221 pub fn models(&self) -> &LanguageModels {
222 &self.models
223 }
224
225 async fn maintain_project_context(
226 this: WeakEntity<Self>,
227 mut needs_refresh: watch::Receiver<()>,
228 cx: &mut AsyncApp,
229 ) -> Result<()> {
230 while needs_refresh.changed().await.is_ok() {
231 let project_context = this
232 .update(cx, |this, cx| {
233 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
234 })?
235 .await;
236 this.update(cx, |this, _| this.project_context.replace(project_context))?;
237 }
238
239 Ok(())
240 }
241
242 fn build_project_context(
243 project: &Entity<Project>,
244 prompt_store: Option<&Entity<PromptStore>>,
245 cx: &mut App,
246 ) -> Task<ProjectContext> {
247 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
248 let worktree_tasks = worktrees
249 .into_iter()
250 .map(|worktree| {
251 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
252 })
253 .collect::<Vec<_>>();
254 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
255 prompt_store.read_with(cx, |prompt_store, cx| {
256 let prompts = prompt_store.default_prompt_metadata();
257 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
258 let contents = prompt_store.load(prompt_metadata.id, cx);
259 async move { (contents.await, prompt_metadata) }
260 });
261 cx.background_spawn(future::join_all(load_tasks))
262 })
263 } else {
264 Task::ready(vec![])
265 };
266
267 cx.spawn(async move |_cx| {
268 let (worktrees, default_user_rules) =
269 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
270
271 let worktrees = worktrees
272 .into_iter()
273 .map(|(worktree, _rules_error)| {
274 // TODO: show error message
275 // if let Some(rules_error) = rules_error {
276 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
277 // }
278 worktree
279 })
280 .collect::<Vec<_>>();
281
282 let default_user_rules = default_user_rules
283 .into_iter()
284 .flat_map(|(contents, prompt_metadata)| match contents {
285 Ok(contents) => Some(UserRulesContext {
286 uuid: match prompt_metadata.id {
287 PromptId::User { uuid } => uuid,
288 PromptId::EditWorkflow => return None,
289 },
290 title: prompt_metadata.title.map(|title| title.to_string()),
291 contents,
292 }),
293 Err(_err) => {
294 // TODO: show error message
295 // this.update(cx, |_, cx| {
296 // cx.emit(RulesLoadingError {
297 // message: format!("{err:?}").into(),
298 // });
299 // })
300 // .ok();
301 None
302 }
303 })
304 .collect::<Vec<_>>();
305
306 ProjectContext::new(worktrees, default_user_rules)
307 })
308 }
309
310 fn load_worktree_info_for_system_prompt(
311 worktree: Entity<Worktree>,
312 project: Entity<Project>,
313 cx: &mut App,
314 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
315 let tree = worktree.read(cx);
316 let root_name = tree.root_name().into();
317 let abs_path = tree.abs_path();
318
319 let mut context = WorktreeContext {
320 root_name,
321 abs_path,
322 rules_file: None,
323 };
324
325 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
326 let Some(rules_task) = rules_task else {
327 return Task::ready((context, None));
328 };
329
330 cx.spawn(async move |_| {
331 let (rules_file, rules_file_error) = match rules_task.await {
332 Ok(rules_file) => (Some(rules_file), None),
333 Err(err) => (
334 None,
335 Some(RulesLoadingError {
336 message: format!("{err}").into(),
337 }),
338 ),
339 };
340 context.rules_file = rules_file;
341 (context, rules_file_error)
342 })
343 }
344
345 fn load_worktree_rules_file(
346 worktree: Entity<Worktree>,
347 project: Entity<Project>,
348 cx: &mut App,
349 ) -> Option<Task<Result<RulesFileContext>>> {
350 let worktree = worktree.read(cx);
351 let worktree_id = worktree.id();
352 let selected_rules_file = RULES_FILE_NAMES
353 .into_iter()
354 .filter_map(|name| {
355 worktree
356 .entry_for_path(name)
357 .filter(|entry| entry.is_file())
358 .map(|entry| entry.path.clone())
359 })
360 .next();
361
362 // Note that Cline supports `.clinerules` being a directory, but that is not currently
363 // supported. This doesn't seem to occur often in GitHub repositories.
364 selected_rules_file.map(|path_in_worktree| {
365 let project_path = ProjectPath {
366 worktree_id,
367 path: path_in_worktree.clone(),
368 };
369 let buffer_task =
370 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
371 let rope_task = cx.spawn(async move |cx| {
372 buffer_task.await?.read_with(cx, |buffer, cx| {
373 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
374 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
375 })?
376 });
377 // Build a string from the rope on a background thread.
378 cx.background_spawn(async move {
379 let (project_entry_id, rope) = rope_task.await?;
380 anyhow::Ok(RulesFileContext {
381 path_in_worktree,
382 text: rope.to_string().trim().to_string(),
383 project_entry_id: project_entry_id.to_usize(),
384 })
385 })
386 })
387 }
388
389 fn handle_project_event(
390 &mut self,
391 _project: Entity<Project>,
392 event: &project::Event,
393 _cx: &mut Context<Self>,
394 ) {
395 match event {
396 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
397 self.project_context_needs_refresh.send(()).ok();
398 }
399 project::Event::WorktreeUpdatedEntries(_, items) => {
400 if items.iter().any(|(path, _, _)| {
401 RULES_FILE_NAMES
402 .iter()
403 .any(|name| path.as_ref() == Path::new(name))
404 }) {
405 self.project_context_needs_refresh.send(()).ok();
406 }
407 }
408 _ => {}
409 }
410 }
411
412 fn handle_prompts_updated_event(
413 &mut self,
414 _prompt_store: Entity<PromptStore>,
415 _event: &prompt_store::PromptsUpdatedEvent,
416 _cx: &mut Context<Self>,
417 ) {
418 self.project_context_needs_refresh.send(()).ok();
419 }
420
421 fn handle_models_updated_event(
422 &mut self,
423 _registry: Entity<LanguageModelRegistry>,
424 _event: &language_model::Event,
425 cx: &mut Context<Self>,
426 ) {
427 self.models.refresh_list(cx);
428 for session in self.sessions.values_mut() {
429 session.thread.update(cx, |thread, _| {
430 let model_id = LanguageModels::model_id(&thread.model());
431 if let Some(model) = self.models.model_from_id(&model_id) {
432 thread.set_model(model.clone());
433 }
434 });
435 }
436 }
437}
438
/// Wrapper struct that implements the AgentConnection trait (and `AgentModelSelector`)
/// on top of a shared [`NativeAgent`] entity. Cloning it only clones the entity handle.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
442
443impl NativeAgentConnection {
444 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
445 self.0
446 .read(cx)
447 .sessions
448 .get(session_id)
449 .map(|session| session.thread.clone())
450 }
451
452 fn run_turn(
453 &self,
454 session_id: acp::SessionId,
455 cx: &mut App,
456 f: impl 'static
457 + FnOnce(
458 Entity<Thread>,
459 &mut App,
460 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
461 ) -> Task<Result<acp::PromptResponse>> {
462 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
463 agent
464 .sessions
465 .get_mut(&session_id)
466 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
467 }) else {
468 return Task::ready(Err(anyhow!("Session not found")));
469 };
470 log::debug!("Found session for: {}", session_id);
471
472 let mut response_stream = match f(thread, cx) {
473 Ok(stream) => stream,
474 Err(err) => return Task::ready(Err(err)),
475 };
476 cx.spawn(async move |cx| {
477 // Handle response stream and forward to session.acp_thread
478 while let Some(result) = response_stream.next().await {
479 match result {
480 Ok(event) => {
481 log::trace!("Received completion event: {:?}", event);
482
483 match event {
484 AgentResponseEvent::Text(text) => {
485 acp_thread.update(cx, |thread, cx| {
486 thread.push_assistant_content_block(
487 acp::ContentBlock::Text(acp::TextContent {
488 text,
489 annotations: None,
490 }),
491 false,
492 cx,
493 )
494 })?;
495 }
496 AgentResponseEvent::Thinking(text) => {
497 acp_thread.update(cx, |thread, cx| {
498 thread.push_assistant_content_block(
499 acp::ContentBlock::Text(acp::TextContent {
500 text,
501 annotations: None,
502 }),
503 true,
504 cx,
505 )
506 })?;
507 }
508 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
509 tool_call,
510 options,
511 response,
512 }) => {
513 let recv = acp_thread.update(cx, |thread, cx| {
514 thread.request_tool_call_authorization(tool_call, options, cx)
515 })?;
516 cx.background_spawn(async move {
517 if let Some(recv) = recv.log_err()
518 && let Some(option) = recv
519 .await
520 .context("authorization sender was dropped")
521 .log_err()
522 {
523 response
524 .send(option)
525 .map(|_| anyhow!("authorization receiver was dropped"))
526 .log_err();
527 }
528 })
529 .detach();
530 }
531 AgentResponseEvent::ToolCall(tool_call) => {
532 acp_thread.update(cx, |thread, cx| {
533 thread.upsert_tool_call(tool_call, cx)
534 })??;
535 }
536 AgentResponseEvent::ToolCallUpdate(update) => {
537 acp_thread.update(cx, |thread, cx| {
538 thread.update_tool_call(update, cx)
539 })??;
540 }
541 AgentResponseEvent::Stop(stop_reason) => {
542 log::debug!("Assistant message complete: {:?}", stop_reason);
543 return Ok(acp::PromptResponse { stop_reason });
544 }
545 }
546 }
547 Err(e) => {
548 log::error!("Error in model response stream: {:?}", e);
549 return Err(e);
550 }
551 }
552 }
553
554 log::info!("Response stream completed");
555 anyhow::Ok(acp::PromptResponse {
556 stop_reason: acp::StopReason::EndTurn,
557 })
558 })
559 }
560}
561
562impl AgentModelSelector for NativeAgentConnection {
563 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
564 log::debug!("NativeAgentConnection::list_models called");
565 let list = self.0.read(cx).models.model_list.clone();
566 Task::ready(if list.is_empty() {
567 Err(anyhow::anyhow!("No models available"))
568 } else {
569 Ok(list)
570 })
571 }
572
573 fn select_model(
574 &self,
575 session_id: acp::SessionId,
576 model_id: acp_thread::AgentModelId,
577 cx: &mut App,
578 ) -> Task<Result<()>> {
579 log::info!("Setting model for session {}: {}", session_id, model_id);
580 let Some(thread) = self
581 .0
582 .read(cx)
583 .sessions
584 .get(&session_id)
585 .map(|session| session.thread.clone())
586 else {
587 return Task::ready(Err(anyhow!("Session not found")));
588 };
589
590 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
591 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
592 };
593
594 thread.update(cx, |thread, _cx| {
595 thread.set_model(model.clone());
596 });
597
598 update_settings_file::<AgentSettings>(
599 self.0.read(cx).fs.clone(),
600 cx,
601 move |settings, _cx| {
602 settings.set_model(model);
603 },
604 );
605
606 Task::ready(Ok(()))
607 }
608
609 fn selected_model(
610 &self,
611 session_id: &acp::SessionId,
612 cx: &mut App,
613 ) -> Task<Result<acp_thread::AgentModelInfo>> {
614 let session_id = session_id.clone();
615
616 let Some(thread) = self
617 .0
618 .read(cx)
619 .sessions
620 .get(&session_id)
621 .map(|session| session.thread.clone())
622 else {
623 return Task::ready(Err(anyhow!("Session not found")));
624 };
625 let model = thread.read(cx).model().clone();
626 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
627 else {
628 return Task::ready(Err(anyhow!("Provider not found")));
629 };
630 Task::ready(Ok(LanguageModels::map_language_model_to_info(
631 &model, &provider,
632 )))
633 }
634
635 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
636 self.0.read(cx).models.watch()
637 }
638}
639
640impl acp_thread::AgentConnection for NativeAgentConnection {
641 fn new_thread(
642 self: Rc<Self>,
643 project: Entity<Project>,
644 cwd: &Path,
645 cx: &mut App,
646 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
647 let agent = self.0.clone();
648 log::info!("Creating new thread for project at: {:?}", cwd);
649
650 cx.spawn(async move |cx| {
651 log::debug!("Starting thread creation in async context");
652
653 // Generate session ID
654 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
655 log::info!("Created session with ID: {}", session_id);
656
657 // Create AcpThread
658 let acp_thread = cx.update(|cx| {
659 cx.new(|cx| {
660 acp_thread::AcpThread::new(
661 "agent2",
662 self.clone(),
663 project.clone(),
664 session_id.clone(),
665 cx,
666 )
667 })
668 })?;
669 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
670
671 // Create Thread
672 let thread = agent.update(
673 cx,
674 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
675 // Fetch default model from registry settings
676 let registry = LanguageModelRegistry::read_global(cx);
677
678 // Log available models for debugging
679 let available_count = registry.available_models(cx).count();
680 log::debug!("Total available models: {}", available_count);
681
682 let default_model = registry
683 .default_model()
684 .and_then(|default_model| {
685 agent
686 .models
687 .model_from_id(&LanguageModels::model_id(&default_model.model))
688 })
689 .ok_or_else(|| {
690 log::warn!("No default model configured in settings");
691 anyhow!(
692 "No default model. Please configure a default model in settings."
693 )
694 })?;
695
696 let thread = cx.new(|cx| {
697 let mut thread = Thread::new(
698 project.clone(),
699 agent.project_context.clone(),
700 agent.context_server_registry.clone(),
701 action_log.clone(),
702 agent.templates.clone(),
703 default_model,
704 cx,
705 );
706 thread.add_tool(CopyPathTool::new(project.clone()));
707 thread.add_tool(CreateDirectoryTool::new(project.clone()));
708 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
709 thread.add_tool(DiagnosticsTool::new(project.clone()));
710 thread.add_tool(EditFileTool::new(cx.entity()));
711 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
712 thread.add_tool(FindPathTool::new(project.clone()));
713 thread.add_tool(GrepTool::new(project.clone()));
714 thread.add_tool(ListDirectoryTool::new(project.clone()));
715 thread.add_tool(MovePathTool::new(project.clone()));
716 thread.add_tool(NowTool);
717 thread.add_tool(OpenTool::new(project.clone()));
718 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
719 thread.add_tool(TerminalTool::new(project.clone(), cx));
720 thread.add_tool(ThinkingTool);
721 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
722 thread
723 });
724
725 Ok(thread)
726 },
727 )??;
728
729 // Store the session
730 agent.update(cx, |agent, cx| {
731 agent.sessions.insert(
732 session_id,
733 Session {
734 thread,
735 acp_thread: acp_thread.downgrade(),
736 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
737 this.sessions.remove(acp_thread.session_id());
738 }),
739 },
740 );
741 })?;
742
743 Ok(acp_thread)
744 })
745 }
746
747 fn auth_methods(&self) -> &[acp::AuthMethod] {
748 &[] // No auth for in-process
749 }
750
751 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
752 Task::ready(Ok(()))
753 }
754
755 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
756 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
757 }
758
759 fn prompt(
760 &self,
761 id: Option<acp_thread::UserMessageId>,
762 params: acp::PromptRequest,
763 cx: &mut App,
764 ) -> Task<Result<acp::PromptResponse>> {
765 let id = id.expect("UserMessageId is required");
766 let session_id = params.session_id.clone();
767 log::info!("Received prompt request for session: {}", session_id);
768 log::debug!("Prompt blocks count: {}", params.prompt.len());
769
770 self.run_turn(session_id, cx, |thread, cx| {
771 let content: Vec<UserMessageContent> = params
772 .prompt
773 .into_iter()
774 .map(Into::into)
775 .collect::<Vec<_>>();
776 log::info!("Converted prompt to message: {} chars", content.len());
777 log::debug!("Message id: {:?}", id);
778 log::debug!("Message content: {:?}", content);
779
780 Ok(thread.update(cx, |thread, cx| {
781 log::info!(
782 "Sending message to thread with model: {:?}",
783 thread.model().name()
784 );
785 thread.send(id, content, cx)
786 }))
787 })
788 }
789
790 fn resume(
791 &self,
792 session_id: &acp::SessionId,
793 _cx: &mut App,
794 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
795 Some(Rc::new(NativeAgentSessionResume {
796 connection: self.clone(),
797 session_id: session_id.clone(),
798 }) as _)
799 }
800
801 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
802 log::info!("Cancelling on session: {}", session_id);
803 self.0.update(cx, |agent, cx| {
804 if let Some(agent) = agent.sessions.get(session_id) {
805 agent.thread.update(cx, |thread, _cx| thread.cancel());
806 }
807 });
808 }
809
810 fn session_editor(
811 &self,
812 session_id: &agent_client_protocol::SessionId,
813 cx: &mut App,
814 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
815 self.0.update(cx, |agent, _cx| {
816 agent
817 .sessions
818 .get(session_id)
819 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
820 })
821 }
822
823 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
824 self
825 }
826}
827
/// Session editor that edits the message history of a session's internal [`Thread`].
struct NativeAgentSessionEditor(Entity<Thread>);
829
830impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
831 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
832 Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
833 }
834}
835
/// Resume handle for a single session of a [`NativeAgentConnection`].
struct NativeAgentSessionResume {
    /// Connection whose agent owns the session.
    connection: NativeAgentConnection,
    /// Session to resume.
    session_id: acp::SessionId,
}
840
841impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
842 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
843 self.connection
844 .run_turn(self.session_id.clone(), cx, |thread, cx| {
845 thread.update(cx, |thread, cx| thread.resume(cx))
846 })
847 }
848}
849
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// The agent's shared project context follows worktree additions and the creation
    /// of a rules file inside a worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project starts without worktrees, so the context has none either.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a visible worktree adds it to the context (no rules file yet).
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// With the test model registry, `list_models` returns a single "Fake" group
    /// containing the `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model switches the session's thread to it and writes it as
    /// `agent.default_model` in the settings file, replacing the previous value.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Start from a settings file that names a different default model.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().id().0, "fake");
            });
        });

        // The settings file is written asynchronously; let it finish.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Installs the globals these tests rely on: a test settings store, project and
    /// agent settings, language support, and a test language model registry.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}