1use crate::{
2 AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
3 DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
4 MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
5 ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
6};
7use acp_thread::AgentModelSelector;
8use agent_client_protocol as acp;
9use agent_settings::AgentSettings;
10use anyhow::{Context as _, Result, anyhow};
11use collections::{HashSet, IndexMap};
12use fs::Fs;
13use futures::channel::mpsc;
14use futures::{StreamExt, future};
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
17};
18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
19use project::{Project, ProjectItem, ProjectPath, Worktree};
20use prompt_store::{
21 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
22};
23use settings::update_settings_file;
24use std::any::Any;
25use std::collections::HashMap;
26use std::path::Path;
27use std::rc::Rc;
28use std::sync::Arc;
29use util::ResultExt;
30
/// Worktree-relative paths of rules files recognized for the system prompt.
///
/// Order matters: when a worktree contains several of these files, the first
/// one in this list that exists as a regular file is the one that gets loaded.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
42
43pub struct RulesLoadingError {
44 pub message: SharedString,
45}
46
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    /// Held weakly: the session does not keep the `AcpThread` alive.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Observation of the `AcpThread`'s release; its callback removes this
    /// session from `NativeAgent::sessions`.
    _subscription: Subscription,
}
55
/// Cache of the language models offered by all authenticated providers,
/// rebuilt by `refresh_list` whenever the registry changes.
pub struct LanguageModels {
    /// Access language model by ID (`"<provider_id>/<model_id>"`)
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver half handed out (cloned) by `watch()`.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
}
64
65impl LanguageModels {
66 fn new(cx: &App) -> Self {
67 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
68 let mut this = Self {
69 models: HashMap::default(),
70 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
71 refresh_models_rx,
72 refresh_models_tx,
73 };
74 this.refresh_list(cx);
75 this
76 }
77
78 fn refresh_list(&mut self, cx: &App) {
79 let providers = LanguageModelRegistry::global(cx)
80 .read(cx)
81 .providers()
82 .into_iter()
83 .filter(|provider| provider.is_authenticated(cx))
84 .collect::<Vec<_>>();
85
86 let mut language_model_list = IndexMap::default();
87 let mut recommended_models = HashSet::default();
88
89 let mut recommended = Vec::new();
90 for provider in &providers {
91 for model in provider.recommended_models(cx) {
92 recommended_models.insert(model.id());
93 recommended.push(Self::map_language_model_to_info(&model, provider));
94 }
95 }
96 if !recommended.is_empty() {
97 language_model_list.insert(
98 acp_thread::AgentModelGroupName("Recommended".into()),
99 recommended,
100 );
101 }
102
103 let mut models = HashMap::default();
104 for provider in providers {
105 let mut provider_models = Vec::new();
106 for model in provider.provided_models(cx) {
107 let model_info = Self::map_language_model_to_info(&model, &provider);
108 let model_id = model_info.id.clone();
109 if !recommended_models.contains(&model.id()) {
110 provider_models.push(model_info);
111 }
112 models.insert(model_id, model);
113 }
114 if !provider_models.is_empty() {
115 language_model_list.insert(
116 acp_thread::AgentModelGroupName(provider.name().0.clone()),
117 provider_models,
118 );
119 }
120 }
121
122 self.models = models;
123 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
124 self.refresh_models_tx.send(()).ok();
125 }
126
127 fn watch(&self) -> watch::Receiver<()> {
128 self.refresh_models_rx.clone()
129 }
130
131 pub fn model_from_id(
132 &self,
133 model_id: &acp_thread::AgentModelId,
134 ) -> Option<Arc<dyn LanguageModel>> {
135 self.models.get(model_id).cloned()
136 }
137
138 fn map_language_model_to_info(
139 model: &Arc<dyn LanguageModel>,
140 provider: &Arc<dyn LanguageModelProvider>,
141 ) -> acp_thread::AgentModelInfo {
142 acp_thread::AgentModelInfo {
143 id: Self::model_id(model),
144 name: model.name().0,
145 icon: Some(provider.icon()),
146 }
147 }
148
149 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
150 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
151 }
152}
153
/// In-process agent that owns every native session for a single project.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled by the project / prompt-store event handlers to request a
    /// rebuild of `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`; stored here so it lives as long
    /// as the agent (presumably dropping the gpui `Task` stops it — confirm).
    _maintain_project_context: Task<Result<()>>,
    /// Registry handed to every new `Thread`, built from the project's
    /// context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    project: Entity<Project>,
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem used to persist settings (e.g. the selected model).
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, language-model registry and prompt-store events.
    _subscriptions: Vec<Subscription>,
}
171
172impl NativeAgent {
173 pub async fn new(
174 project: Entity<Project>,
175 templates: Arc<Templates>,
176 prompt_store: Option<Entity<PromptStore>>,
177 fs: Arc<dyn Fs>,
178 cx: &mut AsyncApp,
179 ) -> Result<Entity<NativeAgent>> {
180 log::info!("Creating new NativeAgent");
181
182 let project_context = cx
183 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
184 .await;
185
186 cx.new(|cx| {
187 let mut subscriptions = vec![
188 cx.subscribe(&project, Self::handle_project_event),
189 cx.subscribe(
190 &LanguageModelRegistry::global(cx),
191 Self::handle_models_updated_event,
192 ),
193 ];
194 if let Some(prompt_store) = prompt_store.as_ref() {
195 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
196 }
197
198 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
199 watch::channel(());
200 Self {
201 sessions: HashMap::new(),
202 project_context: cx.new(|_| project_context),
203 project_context_needs_refresh: project_context_needs_refresh_tx,
204 _maintain_project_context: cx.spawn(async move |this, cx| {
205 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
206 }),
207 context_server_registry: cx.new(|cx| {
208 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
209 }),
210 templates,
211 models: LanguageModels::new(cx),
212 project,
213 prompt_store,
214 fs,
215 _subscriptions: subscriptions,
216 }
217 })
218 }
219
220 pub fn models(&self) -> &LanguageModels {
221 &self.models
222 }
223
224 async fn maintain_project_context(
225 this: WeakEntity<Self>,
226 mut needs_refresh: watch::Receiver<()>,
227 cx: &mut AsyncApp,
228 ) -> Result<()> {
229 while needs_refresh.changed().await.is_ok() {
230 let project_context = this
231 .update(cx, |this, cx| {
232 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
233 })?
234 .await;
235 this.update(cx, |this, cx| {
236 this.project_context = cx.new(|_| project_context);
237 })?;
238 }
239
240 Ok(())
241 }
242
243 fn build_project_context(
244 project: &Entity<Project>,
245 prompt_store: Option<&Entity<PromptStore>>,
246 cx: &mut App,
247 ) -> Task<ProjectContext> {
248 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
249 let worktree_tasks = worktrees
250 .into_iter()
251 .map(|worktree| {
252 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
253 })
254 .collect::<Vec<_>>();
255 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
256 prompt_store.read_with(cx, |prompt_store, cx| {
257 let prompts = prompt_store.default_prompt_metadata();
258 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
259 let contents = prompt_store.load(prompt_metadata.id, cx);
260 async move { (contents.await, prompt_metadata) }
261 });
262 cx.background_spawn(future::join_all(load_tasks))
263 })
264 } else {
265 Task::ready(vec![])
266 };
267
268 cx.spawn(async move |_cx| {
269 let (worktrees, default_user_rules) =
270 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
271
272 let worktrees = worktrees
273 .into_iter()
274 .map(|(worktree, _rules_error)| {
275 // TODO: show error message
276 // if let Some(rules_error) = rules_error {
277 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
278 // }
279 worktree
280 })
281 .collect::<Vec<_>>();
282
283 let default_user_rules = default_user_rules
284 .into_iter()
285 .flat_map(|(contents, prompt_metadata)| match contents {
286 Ok(contents) => Some(UserRulesContext {
287 uuid: match prompt_metadata.id {
288 PromptId::User { uuid } => uuid,
289 PromptId::EditWorkflow => return None,
290 },
291 title: prompt_metadata.title.map(|title| title.to_string()),
292 contents,
293 }),
294 Err(_err) => {
295 // TODO: show error message
296 // this.update(cx, |_, cx| {
297 // cx.emit(RulesLoadingError {
298 // message: format!("{err:?}").into(),
299 // });
300 // })
301 // .ok();
302 None
303 }
304 })
305 .collect::<Vec<_>>();
306
307 ProjectContext::new(worktrees, default_user_rules)
308 })
309 }
310
311 fn load_worktree_info_for_system_prompt(
312 worktree: Entity<Worktree>,
313 project: Entity<Project>,
314 cx: &mut App,
315 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
316 let tree = worktree.read(cx);
317 let root_name = tree.root_name().into();
318 let abs_path = tree.abs_path();
319
320 let mut context = WorktreeContext {
321 root_name,
322 abs_path,
323 rules_file: None,
324 };
325
326 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
327 let Some(rules_task) = rules_task else {
328 return Task::ready((context, None));
329 };
330
331 cx.spawn(async move |_| {
332 let (rules_file, rules_file_error) = match rules_task.await {
333 Ok(rules_file) => (Some(rules_file), None),
334 Err(err) => (
335 None,
336 Some(RulesLoadingError {
337 message: format!("{err}").into(),
338 }),
339 ),
340 };
341 context.rules_file = rules_file;
342 (context, rules_file_error)
343 })
344 }
345
346 fn load_worktree_rules_file(
347 worktree: Entity<Worktree>,
348 project: Entity<Project>,
349 cx: &mut App,
350 ) -> Option<Task<Result<RulesFileContext>>> {
351 let worktree = worktree.read(cx);
352 let worktree_id = worktree.id();
353 let selected_rules_file = RULES_FILE_NAMES
354 .into_iter()
355 .filter_map(|name| {
356 worktree
357 .entry_for_path(name)
358 .filter(|entry| entry.is_file())
359 .map(|entry| entry.path.clone())
360 })
361 .next();
362
363 // Note that Cline supports `.clinerules` being a directory, but that is not currently
364 // supported. This doesn't seem to occur often in GitHub repositories.
365 selected_rules_file.map(|path_in_worktree| {
366 let project_path = ProjectPath {
367 worktree_id,
368 path: path_in_worktree.clone(),
369 };
370 let buffer_task =
371 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
372 let rope_task = cx.spawn(async move |cx| {
373 buffer_task.await?.read_with(cx, |buffer, cx| {
374 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
375 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
376 })?
377 });
378 // Build a string from the rope on a background thread.
379 cx.background_spawn(async move {
380 let (project_entry_id, rope) = rope_task.await?;
381 anyhow::Ok(RulesFileContext {
382 path_in_worktree,
383 text: rope.to_string().trim().to_string(),
384 project_entry_id: project_entry_id.to_usize(),
385 })
386 })
387 })
388 }
389
390 fn handle_project_event(
391 &mut self,
392 _project: Entity<Project>,
393 event: &project::Event,
394 _cx: &mut Context<Self>,
395 ) {
396 match event {
397 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
398 self.project_context_needs_refresh.send(()).ok();
399 }
400 project::Event::WorktreeUpdatedEntries(_, items) => {
401 if items.iter().any(|(path, _, _)| {
402 RULES_FILE_NAMES
403 .iter()
404 .any(|name| path.as_ref() == Path::new(name))
405 }) {
406 self.project_context_needs_refresh.send(()).ok();
407 }
408 }
409 _ => {}
410 }
411 }
412
413 fn handle_prompts_updated_event(
414 &mut self,
415 _prompt_store: Entity<PromptStore>,
416 _event: &prompt_store::PromptsUpdatedEvent,
417 _cx: &mut Context<Self>,
418 ) {
419 self.project_context_needs_refresh.send(()).ok();
420 }
421
422 fn handle_models_updated_event(
423 &mut self,
424 _registry: Entity<LanguageModelRegistry>,
425 _event: &language_model::Event,
426 cx: &mut Context<Self>,
427 ) {
428 self.models.refresh_list(cx);
429
430 let default_model = LanguageModelRegistry::read_global(cx)
431 .default_model()
432 .map(|m| m.model.clone());
433
434 for session in self.sessions.values_mut() {
435 session.thread.update(cx, |thread, cx| {
436 if thread.model().is_none()
437 && let Some(model) = default_model.clone()
438 {
439 thread.set_model(model);
440 cx.notify();
441 }
442 });
443 }
444 }
445}
446
/// Wrapper struct that implements the AgentConnection trait
/// (and `AgentModelSelector`) on top of a shared [`NativeAgent`] entity.
/// Cloning it clones the entity handle, so all clones drive the same agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
450
451impl NativeAgentConnection {
452 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
453 self.0
454 .read(cx)
455 .sessions
456 .get(session_id)
457 .map(|session| session.thread.clone())
458 }
459
460 fn run_turn(
461 &self,
462 session_id: acp::SessionId,
463 cx: &mut App,
464 f: impl 'static
465 + FnOnce(
466 Entity<Thread>,
467 &mut App,
468 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
469 ) -> Task<Result<acp::PromptResponse>> {
470 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
471 agent
472 .sessions
473 .get_mut(&session_id)
474 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
475 }) else {
476 return Task::ready(Err(anyhow!("Session not found")));
477 };
478 log::debug!("Found session for: {}", session_id);
479
480 let mut response_stream = match f(thread, cx) {
481 Ok(stream) => stream,
482 Err(err) => return Task::ready(Err(err)),
483 };
484 cx.spawn(async move |cx| {
485 // Handle response stream and forward to session.acp_thread
486 while let Some(result) = response_stream.next().await {
487 match result {
488 Ok(event) => {
489 log::trace!("Received completion event: {:?}", event);
490
491 match event {
492 AgentResponseEvent::Text(text) => {
493 acp_thread.update(cx, |thread, cx| {
494 thread.push_assistant_content_block(
495 acp::ContentBlock::Text(acp::TextContent {
496 text,
497 annotations: None,
498 }),
499 false,
500 cx,
501 )
502 })?;
503 }
504 AgentResponseEvent::Thinking(text) => {
505 acp_thread.update(cx, |thread, cx| {
506 thread.push_assistant_content_block(
507 acp::ContentBlock::Text(acp::TextContent {
508 text,
509 annotations: None,
510 }),
511 true,
512 cx,
513 )
514 })?;
515 }
516 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
517 tool_call,
518 options,
519 response,
520 }) => {
521 let recv = acp_thread.update(cx, |thread, cx| {
522 thread.request_tool_call_authorization(tool_call, options, cx)
523 })?;
524 cx.background_spawn(async move {
525 if let Some(recv) = recv.log_err()
526 && let Some(option) = recv
527 .await
528 .context("authorization sender was dropped")
529 .log_err()
530 {
531 response
532 .send(option)
533 .map(|_| anyhow!("authorization receiver was dropped"))
534 .log_err();
535 }
536 })
537 .detach();
538 }
539 AgentResponseEvent::ToolCall(tool_call) => {
540 acp_thread.update(cx, |thread, cx| {
541 thread.upsert_tool_call(tool_call, cx)
542 })??;
543 }
544 AgentResponseEvent::ToolCallUpdate(update) => {
545 acp_thread.update(cx, |thread, cx| {
546 thread.update_tool_call(update, cx)
547 })??;
548 }
549 AgentResponseEvent::Stop(stop_reason) => {
550 log::debug!("Assistant message complete: {:?}", stop_reason);
551 return Ok(acp::PromptResponse { stop_reason });
552 }
553 }
554 }
555 Err(e) => {
556 log::error!("Error in model response stream: {:?}", e);
557 return Err(e);
558 }
559 }
560 }
561
562 log::info!("Response stream completed");
563 anyhow::Ok(acp::PromptResponse {
564 stop_reason: acp::StopReason::EndTurn,
565 })
566 })
567 }
568}
569
570impl AgentModelSelector for NativeAgentConnection {
571 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
572 log::debug!("NativeAgentConnection::list_models called");
573 let list = self.0.read(cx).models.model_list.clone();
574 Task::ready(if list.is_empty() {
575 Err(anyhow::anyhow!("No models available"))
576 } else {
577 Ok(list)
578 })
579 }
580
581 fn select_model(
582 &self,
583 session_id: acp::SessionId,
584 model_id: acp_thread::AgentModelId,
585 cx: &mut App,
586 ) -> Task<Result<()>> {
587 log::info!("Setting model for session {}: {}", session_id, model_id);
588 let Some(thread) = self
589 .0
590 .read(cx)
591 .sessions
592 .get(&session_id)
593 .map(|session| session.thread.clone())
594 else {
595 return Task::ready(Err(anyhow!("Session not found")));
596 };
597
598 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
599 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
600 };
601
602 thread.update(cx, |thread, _cx| {
603 thread.set_model(model.clone());
604 });
605
606 update_settings_file::<AgentSettings>(
607 self.0.read(cx).fs.clone(),
608 cx,
609 move |settings, _cx| {
610 settings.set_model(model);
611 },
612 );
613
614 Task::ready(Ok(()))
615 }
616
617 fn selected_model(
618 &self,
619 session_id: &acp::SessionId,
620 cx: &mut App,
621 ) -> Task<Result<acp_thread::AgentModelInfo>> {
622 let session_id = session_id.clone();
623
624 let Some(thread) = self
625 .0
626 .read(cx)
627 .sessions
628 .get(&session_id)
629 .map(|session| session.thread.clone())
630 else {
631 return Task::ready(Err(anyhow!("Session not found")));
632 };
633 let Some(model) = thread.read(cx).model() else {
634 return Task::ready(Err(anyhow!("Model not found")));
635 };
636 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
637 else {
638 return Task::ready(Err(anyhow!("Provider not found")));
639 };
640 Task::ready(Ok(LanguageModels::map_language_model_to_info(
641 model, &provider,
642 )))
643 }
644
645 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
646 self.0.read(cx).models.watch()
647 }
648}
649
650impl acp_thread::AgentConnection for NativeAgentConnection {
651 fn new_thread(
652 self: Rc<Self>,
653 project: Entity<Project>,
654 cwd: &Path,
655 cx: &mut App,
656 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
657 let agent = self.0.clone();
658 log::info!("Creating new thread for project at: {:?}", cwd);
659
660 cx.spawn(async move |cx| {
661 log::debug!("Starting thread creation in async context");
662
663 // Generate session ID
664 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
665 log::info!("Created session with ID: {}", session_id);
666
667 // Create AcpThread
668 let acp_thread = cx.update(|cx| {
669 cx.new(|cx| {
670 acp_thread::AcpThread::new(
671 "agent2",
672 self.clone(),
673 project.clone(),
674 session_id.clone(),
675 cx,
676 )
677 })
678 })?;
679 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
680
681 // Create Thread
682 let thread = agent.update(
683 cx,
684 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
685 // Fetch default model from registry settings
686 let registry = LanguageModelRegistry::read_global(cx);
687
688 // Log available models for debugging
689 let available_count = registry.available_models(cx).count();
690 log::debug!("Total available models: {}", available_count);
691
692 let default_model = registry.default_model().and_then(|default_model| {
693 agent
694 .models
695 .model_from_id(&LanguageModels::model_id(&default_model.model))
696 });
697
698 let thread = cx.new(|cx| {
699 let mut thread = Thread::new(
700 project.clone(),
701 agent.project_context.clone(),
702 agent.context_server_registry.clone(),
703 action_log.clone(),
704 agent.templates.clone(),
705 default_model,
706 cx,
707 );
708 thread.add_tool(CopyPathTool::new(project.clone()));
709 thread.add_tool(CreateDirectoryTool::new(project.clone()));
710 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
711 thread.add_tool(DiagnosticsTool::new(project.clone()));
712 thread.add_tool(EditFileTool::new(cx.entity()));
713 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
714 thread.add_tool(FindPathTool::new(project.clone()));
715 thread.add_tool(GrepTool::new(project.clone()));
716 thread.add_tool(ListDirectoryTool::new(project.clone()));
717 thread.add_tool(MovePathTool::new(project.clone()));
718 thread.add_tool(NowTool);
719 thread.add_tool(OpenTool::new(project.clone()));
720 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
721 thread.add_tool(TerminalTool::new(project.clone(), cx));
722 thread.add_tool(ThinkingTool);
723 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
724 thread
725 });
726
727 Ok(thread)
728 },
729 )??;
730
731 // Store the session
732 agent.update(cx, |agent, cx| {
733 agent.sessions.insert(
734 session_id,
735 Session {
736 thread,
737 acp_thread: acp_thread.downgrade(),
738 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
739 this.sessions.remove(acp_thread.session_id());
740 }),
741 },
742 );
743 })?;
744
745 Ok(acp_thread)
746 })
747 }
748
749 fn auth_methods(&self) -> &[acp::AuthMethod] {
750 &[] // No auth for in-process
751 }
752
753 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
754 Task::ready(Ok(()))
755 }
756
757 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
758 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
759 }
760
761 fn prompt(
762 &self,
763 id: Option<acp_thread::UserMessageId>,
764 params: acp::PromptRequest,
765 cx: &mut App,
766 ) -> Task<Result<acp::PromptResponse>> {
767 let id = id.expect("UserMessageId is required");
768 let session_id = params.session_id.clone();
769 log::info!("Received prompt request for session: {}", session_id);
770 log::debug!("Prompt blocks count: {}", params.prompt.len());
771
772 self.run_turn(session_id, cx, |thread, cx| {
773 let content: Vec<UserMessageContent> = params
774 .prompt
775 .into_iter()
776 .map(Into::into)
777 .collect::<Vec<_>>();
778 log::info!("Converted prompt to message: {} chars", content.len());
779 log::debug!("Message id: {:?}", id);
780 log::debug!("Message content: {:?}", content);
781
782 thread.update(cx, |thread, cx| thread.send(id, content, cx))
783 })
784 }
785
786 fn resume(
787 &self,
788 session_id: &acp::SessionId,
789 _cx: &mut App,
790 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
791 Some(Rc::new(NativeAgentSessionResume {
792 connection: self.clone(),
793 session_id: session_id.clone(),
794 }) as _)
795 }
796
797 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
798 log::info!("Cancelling on session: {}", session_id);
799 self.0.update(cx, |agent, cx| {
800 if let Some(agent) = agent.sessions.get(session_id) {
801 agent.thread.update(cx, |thread, _cx| thread.cancel());
802 }
803 });
804 }
805
806 fn session_editor(
807 &self,
808 session_id: &agent_client_protocol::SessionId,
809 cx: &mut App,
810 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
811 self.0.update(cx, |agent, _cx| {
812 agent
813 .sessions
814 .get(session_id)
815 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
816 })
817 }
818
819 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
820 self
821 }
822}
823
/// Session editor that operates directly on a session's native [`Thread`].
struct NativeAgentSessionEditor(Entity<Thread>);
825
826impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
827 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
828 Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
829 }
830}
831
/// Handle returned by `NativeAgentConnection::resume` that continues a
/// session's thread via `run_turn`.
struct NativeAgentSessionResume {
    /// Connection whose agent owns the session.
    connection: NativeAgentConnection,
    /// Session to resume.
    session_id: acp::SessionId,
}
836
837impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
838 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
839 self.connection
840 .run_turn(self.session_id.clone(), cx, |thread, cx| {
841 thread.update(cx, |thread, cx| thread.resume(cx))
842 })
843 }
844}
845
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// The agent's project context follows worktree additions and the creation
    /// of a rules file inside a worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts with no worktrees, so the initial context is empty.
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a visible worktree triggers a context refresh.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// With only the test registry's fake provider available, the model list
    /// is a single "Fake" group containing `fake/fake`.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model updates the session's thread and overwrites the
    /// `agent.default_model` entry in the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Seed the settings file with a different default model so the
        // assertions below can only pass if it was rewritten.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write issued by `select_model` complete.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Installs the globals the agent needs: a test settings store, project,
    /// agent and language settings, and a test language-model registry.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}