1use crate::{
2 AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
3 DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
4 MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
5 ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
6};
7use acp_thread::AgentModelSelector;
8use agent_client_protocol as acp;
9use agent_settings::AgentSettings;
10use anyhow::{Context as _, Result, anyhow};
11use collections::{HashSet, IndexMap};
12use fs::Fs;
13use futures::channel::mpsc;
14use futures::{StreamExt, future};
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
17};
18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
19use project::{Project, ProjectItem, ProjectPath, Worktree};
20use prompt_store::{
21 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
22};
23use settings::update_settings_file;
24use std::any::Any;
25use std::collections::HashMap;
26use std::path::Path;
27use std::rc::Rc;
28use std::sync::Arc;
29use util::ResultExt;
30
/// Worktree-root-relative file names that may hold project rules for the agent.
///
/// They are checked in this order and only the first one that exists as a regular file
/// is loaded. Any change to one of these paths triggers a project-context refresh.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
42
43pub struct RulesLoadingError {
44 pub message: SharedString,
45}
46
47/// Holds both the internal Thread and the AcpThread for a session
48struct Session {
49 /// The internal thread that processes messages
50 thread: Entity<Thread>,
51 /// The ACP thread that handles protocol communication
52 acp_thread: WeakEntity<acp_thread::AcpThread>,
53 _subscription: Subscription,
54}
55
56pub struct LanguageModels {
57 /// Access language model by ID
58 models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
59 /// Cached list for returning language model information
60 model_list: acp_thread::AgentModelList,
61 refresh_models_rx: watch::Receiver<()>,
62 refresh_models_tx: watch::Sender<()>,
63}
64
65impl LanguageModels {
66 fn new(cx: &App) -> Self {
67 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
68 let mut this = Self {
69 models: HashMap::default(),
70 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
71 refresh_models_rx,
72 refresh_models_tx,
73 };
74 this.refresh_list(cx);
75 this
76 }
77
78 fn refresh_list(&mut self, cx: &App) {
79 let providers = LanguageModelRegistry::global(cx)
80 .read(cx)
81 .providers()
82 .into_iter()
83 .filter(|provider| provider.is_authenticated(cx))
84 .collect::<Vec<_>>();
85
86 let mut language_model_list = IndexMap::default();
87 let mut recommended_models = HashSet::default();
88
89 let mut recommended = Vec::new();
90 for provider in &providers {
91 for model in provider.recommended_models(cx) {
92 recommended_models.insert(model.id());
93 recommended.push(Self::map_language_model_to_info(&model, provider));
94 }
95 }
96 if !recommended.is_empty() {
97 language_model_list.insert(
98 acp_thread::AgentModelGroupName("Recommended".into()),
99 recommended,
100 );
101 }
102
103 let mut models = HashMap::default();
104 for provider in providers {
105 let mut provider_models = Vec::new();
106 for model in provider.provided_models(cx) {
107 let model_info = Self::map_language_model_to_info(&model, &provider);
108 let model_id = model_info.id.clone();
109 if !recommended_models.contains(&model.id()) {
110 provider_models.push(model_info);
111 }
112 models.insert(model_id, model);
113 }
114 if !provider_models.is_empty() {
115 language_model_list.insert(
116 acp_thread::AgentModelGroupName(provider.name().0.clone()),
117 provider_models,
118 );
119 }
120 }
121
122 self.models = models;
123 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
124 self.refresh_models_tx.send(()).ok();
125 }
126
127 fn watch(&self) -> watch::Receiver<()> {
128 self.refresh_models_rx.clone()
129 }
130
131 pub fn model_from_id(
132 &self,
133 model_id: &acp_thread::AgentModelId,
134 ) -> Option<Arc<dyn LanguageModel>> {
135 self.models.get(model_id).cloned()
136 }
137
138 fn map_language_model_to_info(
139 model: &Arc<dyn LanguageModel>,
140 provider: &Arc<dyn LanguageModelProvider>,
141 ) -> acp_thread::AgentModelInfo {
142 acp_thread::AgentModelInfo {
143 id: Self::model_id(model),
144 name: model.name().0,
145 icon: Some(provider.icon()),
146 }
147 }
148
149 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
150 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
151 }
152}
153
154pub struct NativeAgent {
155 /// Session ID -> Session mapping
156 sessions: HashMap<acp::SessionId, Session>,
157 /// Shared project context for all threads
158 project_context: Entity<ProjectContext>,
159 project_context_needs_refresh: watch::Sender<()>,
160 _maintain_project_context: Task<Result<()>>,
161 context_server_registry: Entity<ContextServerRegistry>,
162 /// Shared templates for all threads
163 templates: Arc<Templates>,
164 /// Cached model information
165 models: LanguageModels,
166 project: Entity<Project>,
167 prompt_store: Option<Entity<PromptStore>>,
168 fs: Arc<dyn Fs>,
169 _subscriptions: Vec<Subscription>,
170}
171
172impl NativeAgent {
173 pub async fn new(
174 project: Entity<Project>,
175 templates: Arc<Templates>,
176 prompt_store: Option<Entity<PromptStore>>,
177 fs: Arc<dyn Fs>,
178 cx: &mut AsyncApp,
179 ) -> Result<Entity<NativeAgent>> {
180 log::info!("Creating new NativeAgent");
181
182 let project_context = cx
183 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
184 .await;
185
186 cx.new(|cx| {
187 let mut subscriptions = vec![
188 cx.subscribe(&project, Self::handle_project_event),
189 cx.subscribe(
190 &LanguageModelRegistry::global(cx),
191 Self::handle_models_updated_event,
192 ),
193 ];
194 if let Some(prompt_store) = prompt_store.as_ref() {
195 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
196 }
197
198 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
199 watch::channel(());
200 Self {
201 sessions: HashMap::new(),
202 project_context: cx.new(|_| project_context),
203 project_context_needs_refresh: project_context_needs_refresh_tx,
204 _maintain_project_context: cx.spawn(async move |this, cx| {
205 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
206 }),
207 context_server_registry: cx.new(|cx| {
208 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
209 }),
210 templates,
211 models: LanguageModels::new(cx),
212 project,
213 prompt_store,
214 fs,
215 _subscriptions: subscriptions,
216 }
217 })
218 }
219
220 pub fn models(&self) -> &LanguageModels {
221 &self.models
222 }
223
224 async fn maintain_project_context(
225 this: WeakEntity<Self>,
226 mut needs_refresh: watch::Receiver<()>,
227 cx: &mut AsyncApp,
228 ) -> Result<()> {
229 while needs_refresh.changed().await.is_ok() {
230 let project_context = this
231 .update(cx, |this, cx| {
232 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
233 })?
234 .await;
235 this.update(cx, |this, cx| {
236 this.project_context = cx.new(|_| project_context);
237 })?;
238 }
239
240 Ok(())
241 }
242
243 fn build_project_context(
244 project: &Entity<Project>,
245 prompt_store: Option<&Entity<PromptStore>>,
246 cx: &mut App,
247 ) -> Task<ProjectContext> {
248 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
249 let worktree_tasks = worktrees
250 .into_iter()
251 .map(|worktree| {
252 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
253 })
254 .collect::<Vec<_>>();
255 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
256 prompt_store.read_with(cx, |prompt_store, cx| {
257 let prompts = prompt_store.default_prompt_metadata();
258 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
259 let contents = prompt_store.load(prompt_metadata.id, cx);
260 async move { (contents.await, prompt_metadata) }
261 });
262 cx.background_spawn(future::join_all(load_tasks))
263 })
264 } else {
265 Task::ready(vec![])
266 };
267
268 cx.spawn(async move |_cx| {
269 let (worktrees, default_user_rules) =
270 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
271
272 let worktrees = worktrees
273 .into_iter()
274 .map(|(worktree, _rules_error)| {
275 // TODO: show error message
276 // if let Some(rules_error) = rules_error {
277 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
278 // }
279 worktree
280 })
281 .collect::<Vec<_>>();
282
283 let default_user_rules = default_user_rules
284 .into_iter()
285 .flat_map(|(contents, prompt_metadata)| match contents {
286 Ok(contents) => Some(UserRulesContext {
287 uuid: match prompt_metadata.id {
288 PromptId::User { uuid } => uuid,
289 PromptId::EditWorkflow => return None,
290 },
291 title: prompt_metadata.title.map(|title| title.to_string()),
292 contents,
293 }),
294 Err(_err) => {
295 // TODO: show error message
296 // this.update(cx, |_, cx| {
297 // cx.emit(RulesLoadingError {
298 // message: format!("{err:?}").into(),
299 // });
300 // })
301 // .ok();
302 None
303 }
304 })
305 .collect::<Vec<_>>();
306
307 ProjectContext::new(worktrees, default_user_rules)
308 })
309 }
310
311 fn load_worktree_info_for_system_prompt(
312 worktree: Entity<Worktree>,
313 project: Entity<Project>,
314 cx: &mut App,
315 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
316 let tree = worktree.read(cx);
317 let root_name = tree.root_name().into();
318 let abs_path = tree.abs_path();
319
320 let mut context = WorktreeContext {
321 root_name,
322 abs_path,
323 rules_file: None,
324 };
325
326 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
327 let Some(rules_task) = rules_task else {
328 return Task::ready((context, None));
329 };
330
331 cx.spawn(async move |_| {
332 let (rules_file, rules_file_error) = match rules_task.await {
333 Ok(rules_file) => (Some(rules_file), None),
334 Err(err) => (
335 None,
336 Some(RulesLoadingError {
337 message: format!("{err}").into(),
338 }),
339 ),
340 };
341 context.rules_file = rules_file;
342 (context, rules_file_error)
343 })
344 }
345
346 fn load_worktree_rules_file(
347 worktree: Entity<Worktree>,
348 project: Entity<Project>,
349 cx: &mut App,
350 ) -> Option<Task<Result<RulesFileContext>>> {
351 let worktree = worktree.read(cx);
352 let worktree_id = worktree.id();
353 let selected_rules_file = RULES_FILE_NAMES
354 .into_iter()
355 .filter_map(|name| {
356 worktree
357 .entry_for_path(name)
358 .filter(|entry| entry.is_file())
359 .map(|entry| entry.path.clone())
360 })
361 .next();
362
363 // Note that Cline supports `.clinerules` being a directory, but that is not currently
364 // supported. This doesn't seem to occur often in GitHub repositories.
365 selected_rules_file.map(|path_in_worktree| {
366 let project_path = ProjectPath {
367 worktree_id,
368 path: path_in_worktree.clone(),
369 };
370 let buffer_task =
371 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
372 let rope_task = cx.spawn(async move |cx| {
373 buffer_task.await?.read_with(cx, |buffer, cx| {
374 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
375 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
376 })?
377 });
378 // Build a string from the rope on a background thread.
379 cx.background_spawn(async move {
380 let (project_entry_id, rope) = rope_task.await?;
381 anyhow::Ok(RulesFileContext {
382 path_in_worktree,
383 text: rope.to_string().trim().to_string(),
384 project_entry_id: project_entry_id.to_usize(),
385 })
386 })
387 })
388 }
389
390 fn handle_project_event(
391 &mut self,
392 _project: Entity<Project>,
393 event: &project::Event,
394 _cx: &mut Context<Self>,
395 ) {
396 match event {
397 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
398 self.project_context_needs_refresh.send(()).ok();
399 }
400 project::Event::WorktreeUpdatedEntries(_, items) => {
401 if items.iter().any(|(path, _, _)| {
402 RULES_FILE_NAMES
403 .iter()
404 .any(|name| path.as_ref() == Path::new(name))
405 }) {
406 self.project_context_needs_refresh.send(()).ok();
407 }
408 }
409 _ => {}
410 }
411 }
412
413 fn handle_prompts_updated_event(
414 &mut self,
415 _prompt_store: Entity<PromptStore>,
416 _event: &prompt_store::PromptsUpdatedEvent,
417 _cx: &mut Context<Self>,
418 ) {
419 self.project_context_needs_refresh.send(()).ok();
420 }
421
422 fn handle_models_updated_event(
423 &mut self,
424 _registry: Entity<LanguageModelRegistry>,
425 _event: &language_model::Event,
426 cx: &mut Context<Self>,
427 ) {
428 self.models.refresh_list(cx);
429
430 let default_model = LanguageModelRegistry::read_global(cx)
431 .default_model()
432 .map(|m| m.model.clone());
433
434 for session in self.sessions.values_mut() {
435 session.thread.update(cx, |thread, cx| {
436 if thread.model().is_none()
437 && let Some(model) = default_model.clone()
438 {
439 thread.set_model(model);
440 cx.notify();
441 }
442 });
443 }
444 }
445}
446
447/// Wrapper struct that implements the AgentConnection trait
448#[derive(Clone)]
449pub struct NativeAgentConnection(pub Entity<NativeAgent>);
450
451impl NativeAgentConnection {
452 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
453 self.0
454 .read(cx)
455 .sessions
456 .get(session_id)
457 .map(|session| session.thread.clone())
458 }
459
460 fn run_turn(
461 &self,
462 session_id: acp::SessionId,
463 cx: &mut App,
464 f: impl 'static
465 + FnOnce(
466 Entity<Thread>,
467 &mut App,
468 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
469 ) -> Task<Result<acp::PromptResponse>> {
470 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
471 agent
472 .sessions
473 .get_mut(&session_id)
474 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
475 }) else {
476 return Task::ready(Err(anyhow!("Session not found")));
477 };
478 log::debug!("Found session for: {}", session_id);
479
480 let mut response_stream = match f(thread, cx) {
481 Ok(stream) => stream,
482 Err(err) => return Task::ready(Err(err)),
483 };
484 cx.spawn(async move |cx| {
485 // Handle response stream and forward to session.acp_thread
486 while let Some(result) = response_stream.next().await {
487 match result {
488 Ok(event) => {
489 log::trace!("Received completion event: {:?}", event);
490
491 match event {
492 AgentResponseEvent::Text(text) => {
493 acp_thread.update(cx, |thread, cx| {
494 thread.push_assistant_content_block(
495 acp::ContentBlock::Text(acp::TextContent {
496 text,
497 annotations: None,
498 }),
499 false,
500 cx,
501 )
502 })?;
503 }
504 AgentResponseEvent::Thinking(text) => {
505 acp_thread.update(cx, |thread, cx| {
506 thread.push_assistant_content_block(
507 acp::ContentBlock::Text(acp::TextContent {
508 text,
509 annotations: None,
510 }),
511 true,
512 cx,
513 )
514 })?;
515 }
516 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
517 tool_call,
518 options,
519 response,
520 }) => {
521 let recv = acp_thread.update(cx, |thread, cx| {
522 thread.request_tool_call_authorization(tool_call, options, cx)
523 })?;
524 cx.background_spawn(async move {
525 if let Some(recv) = recv.log_err()
526 && let Some(option) = recv
527 .await
528 .context("authorization sender was dropped")
529 .log_err()
530 {
531 response
532 .send(option)
533 .map(|_| anyhow!("authorization receiver was dropped"))
534 .log_err();
535 }
536 })
537 .detach();
538 }
539 AgentResponseEvent::ToolCall(tool_call) => {
540 acp_thread.update(cx, |thread, cx| {
541 thread.upsert_tool_call(tool_call, cx)
542 })??;
543 }
544 AgentResponseEvent::ToolCallUpdate(update) => {
545 acp_thread.update(cx, |thread, cx| {
546 thread.update_tool_call(update, cx)
547 })??;
548 }
549 AgentResponseEvent::Retry(status) => {
550 acp_thread.update(cx, |thread, cx| {
551 thread.update_retry_status(status, cx)
552 })?;
553 }
554 AgentResponseEvent::Stop(stop_reason) => {
555 log::debug!("Assistant message complete: {:?}", stop_reason);
556 return Ok(acp::PromptResponse { stop_reason });
557 }
558 }
559 }
560 Err(e) => {
561 log::error!("Error in model response stream: {:?}", e);
562 return Err(e);
563 }
564 }
565 }
566
567 log::info!("Response stream completed");
568 anyhow::Ok(acp::PromptResponse {
569 stop_reason: acp::StopReason::EndTurn,
570 })
571 })
572 }
573}
574
575impl AgentModelSelector for NativeAgentConnection {
576 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
577 log::debug!("NativeAgentConnection::list_models called");
578 let list = self.0.read(cx).models.model_list.clone();
579 Task::ready(if list.is_empty() {
580 Err(anyhow::anyhow!("No models available"))
581 } else {
582 Ok(list)
583 })
584 }
585
586 fn select_model(
587 &self,
588 session_id: acp::SessionId,
589 model_id: acp_thread::AgentModelId,
590 cx: &mut App,
591 ) -> Task<Result<()>> {
592 log::info!("Setting model for session {}: {}", session_id, model_id);
593 let Some(thread) = self
594 .0
595 .read(cx)
596 .sessions
597 .get(&session_id)
598 .map(|session| session.thread.clone())
599 else {
600 return Task::ready(Err(anyhow!("Session not found")));
601 };
602
603 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
604 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
605 };
606
607 thread.update(cx, |thread, _cx| {
608 thread.set_model(model.clone());
609 });
610
611 update_settings_file::<AgentSettings>(
612 self.0.read(cx).fs.clone(),
613 cx,
614 move |settings, _cx| {
615 settings.set_model(model);
616 },
617 );
618
619 Task::ready(Ok(()))
620 }
621
622 fn selected_model(
623 &self,
624 session_id: &acp::SessionId,
625 cx: &mut App,
626 ) -> Task<Result<acp_thread::AgentModelInfo>> {
627 let session_id = session_id.clone();
628
629 let Some(thread) = self
630 .0
631 .read(cx)
632 .sessions
633 .get(&session_id)
634 .map(|session| session.thread.clone())
635 else {
636 return Task::ready(Err(anyhow!("Session not found")));
637 };
638 let Some(model) = thread.read(cx).model() else {
639 return Task::ready(Err(anyhow!("Model not found")));
640 };
641 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
642 else {
643 return Task::ready(Err(anyhow!("Provider not found")));
644 };
645 Task::ready(Ok(LanguageModels::map_language_model_to_info(
646 model, &provider,
647 )))
648 }
649
650 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
651 self.0.read(cx).models.watch()
652 }
653}
654
655impl acp_thread::AgentConnection for NativeAgentConnection {
656 fn new_thread(
657 self: Rc<Self>,
658 project: Entity<Project>,
659 cwd: &Path,
660 cx: &mut App,
661 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
662 let agent = self.0.clone();
663 log::info!("Creating new thread for project at: {:?}", cwd);
664
665 cx.spawn(async move |cx| {
666 log::debug!("Starting thread creation in async context");
667
668 // Generate session ID
669 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
670 log::info!("Created session with ID: {}", session_id);
671
672 // Create AcpThread
673 let acp_thread = cx.update(|cx| {
674 cx.new(|cx| {
675 acp_thread::AcpThread::new(
676 "agent2",
677 self.clone(),
678 project.clone(),
679 session_id.clone(),
680 cx,
681 )
682 })
683 })?;
684 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
685
686 // Create Thread
687 let thread = agent.update(
688 cx,
689 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
690 // Fetch default model from registry settings
691 let registry = LanguageModelRegistry::read_global(cx);
692
693 // Log available models for debugging
694 let available_count = registry.available_models(cx).count();
695 log::debug!("Total available models: {}", available_count);
696
697 let default_model = registry.default_model().and_then(|default_model| {
698 agent
699 .models
700 .model_from_id(&LanguageModels::model_id(&default_model.model))
701 });
702
703 let thread = cx.new(|cx| {
704 let mut thread = Thread::new(
705 project.clone(),
706 agent.project_context.clone(),
707 agent.context_server_registry.clone(),
708 action_log.clone(),
709 agent.templates.clone(),
710 default_model,
711 cx,
712 );
713 thread.add_tool(CopyPathTool::new(project.clone()));
714 thread.add_tool(CreateDirectoryTool::new(project.clone()));
715 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
716 thread.add_tool(DiagnosticsTool::new(project.clone()));
717 thread.add_tool(EditFileTool::new(cx.entity()));
718 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
719 thread.add_tool(FindPathTool::new(project.clone()));
720 thread.add_tool(GrepTool::new(project.clone()));
721 thread.add_tool(ListDirectoryTool::new(project.clone()));
722 thread.add_tool(MovePathTool::new(project.clone()));
723 thread.add_tool(NowTool);
724 thread.add_tool(OpenTool::new(project.clone()));
725 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
726 thread.add_tool(TerminalTool::new(project.clone(), cx));
727 thread.add_tool(ThinkingTool);
728 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
729 thread
730 });
731
732 Ok(thread)
733 },
734 )??;
735
736 // Store the session
737 agent.update(cx, |agent, cx| {
738 agent.sessions.insert(
739 session_id,
740 Session {
741 thread,
742 acp_thread: acp_thread.downgrade(),
743 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
744 this.sessions.remove(acp_thread.session_id());
745 }),
746 },
747 );
748 })?;
749
750 Ok(acp_thread)
751 })
752 }
753
754 fn auth_methods(&self) -> &[acp::AuthMethod] {
755 &[] // No auth for in-process
756 }
757
758 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
759 Task::ready(Ok(()))
760 }
761
762 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
763 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
764 }
765
766 fn prompt(
767 &self,
768 id: Option<acp_thread::UserMessageId>,
769 params: acp::PromptRequest,
770 cx: &mut App,
771 ) -> Task<Result<acp::PromptResponse>> {
772 let id = id.expect("UserMessageId is required");
773 let session_id = params.session_id.clone();
774 log::info!("Received prompt request for session: {}", session_id);
775 log::debug!("Prompt blocks count: {}", params.prompt.len());
776
777 self.run_turn(session_id, cx, |thread, cx| {
778 let content: Vec<UserMessageContent> = params
779 .prompt
780 .into_iter()
781 .map(Into::into)
782 .collect::<Vec<_>>();
783 log::info!("Converted prompt to message: {} chars", content.len());
784 log::debug!("Message id: {:?}", id);
785 log::debug!("Message content: {:?}", content);
786
787 thread.update(cx, |thread, cx| thread.send(id, content, cx))
788 })
789 }
790
791 fn resume(
792 &self,
793 session_id: &acp::SessionId,
794 _cx: &mut App,
795 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
796 Some(Rc::new(NativeAgentSessionResume {
797 connection: self.clone(),
798 session_id: session_id.clone(),
799 }) as _)
800 }
801
802 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
803 log::info!("Cancelling on session: {}", session_id);
804 self.0.update(cx, |agent, cx| {
805 if let Some(agent) = agent.sessions.get(session_id) {
806 agent.thread.update(cx, |thread, _cx| thread.cancel());
807 }
808 });
809 }
810
811 fn session_editor(
812 &self,
813 session_id: &agent_client_protocol::SessionId,
814 cx: &mut App,
815 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
816 self.0.update(cx, |agent, _cx| {
817 agent
818 .sessions
819 .get(session_id)
820 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
821 })
822 }
823
824 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
825 self
826 }
827}
828
829struct NativeAgentSessionEditor(Entity<Thread>);
830
831impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
832 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
833 Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
834 }
835}
836
837struct NativeAgentSessionResume {
838 connection: NativeAgentConnection,
839 session_id: acp::SessionId,
840}
841
842impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
843 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
844 self.connection
845 .run_turn(self.session_id.clone(), cx, |thread, cx| {
846 thread.update(cx, |thread, cx| thread.resume(cx))
847 })
848 }
849}
850
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// The agent's project context should follow worktree additions and the creation of
    /// a `.rules` file inside a worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // Start from a project with no worktrees.
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a visible worktree must show up in the project context once the
        // background refresh has run.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// With only the test registry's fake provider, `list_models` should return a single
    /// "Fake" group containing the fake model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model should update the session's thread and overwrite the
    /// `agent.default_model` entry in the user settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a different default model so the overwrite is
        // observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write finish before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Installs the globals these tests need: the settings store, project and agent
    /// settings, languages, and a test language-model registry.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` because several tests in the same process may call this.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}