1use crate::{
2 AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
3 DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
4 MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
5 ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
6};
7use acp_thread::AgentModelSelector;
8use agent_client_protocol as acp;
9use agent_settings::AgentSettings;
10use anyhow::{Context as _, Result, anyhow};
11use collections::{HashSet, IndexMap};
12use fs::Fs;
13use futures::channel::mpsc;
14use futures::{StreamExt, future};
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
17};
18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
19use project::{Project, ProjectItem, ProjectPath, Worktree};
20use prompt_store::{
21 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
22};
23use settings::update_settings_file;
24use std::any::Any;
25use std::cell::RefCell;
26use std::collections::HashMap;
27use std::path::Path;
28use std::rc::Rc;
29use std::sync::Arc;
30use util::ResultExt;
31
/// Worktree-relative file names that may hold project rules for the agent, in priority order.
///
/// When building a worktree's context, the first name here that exists as a regular file at the
/// worktree root is the one that gets loaded.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
43
44pub struct RulesLoadingError {
45 pub message: SharedString,
46}
47
48/// Holds both the internal Thread and the AcpThread for a session
49struct Session {
50 /// The internal thread that processes messages
51 thread: Entity<Thread>,
52 /// The ACP thread that handles protocol communication
53 acp_thread: WeakEntity<acp_thread::AcpThread>,
54 _subscription: Subscription,
55}
56
57pub struct LanguageModels {
58 /// Access language model by ID
59 models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
60 /// Cached list for returning language model information
61 model_list: acp_thread::AgentModelList,
62 refresh_models_rx: watch::Receiver<()>,
63 refresh_models_tx: watch::Sender<()>,
64}
65
66impl LanguageModels {
67 fn new(cx: &App) -> Self {
68 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
69 let mut this = Self {
70 models: HashMap::default(),
71 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
72 refresh_models_rx,
73 refresh_models_tx,
74 };
75 this.refresh_list(cx);
76 this
77 }
78
79 fn refresh_list(&mut self, cx: &App) {
80 let providers = LanguageModelRegistry::global(cx)
81 .read(cx)
82 .providers()
83 .into_iter()
84 .filter(|provider| provider.is_authenticated(cx))
85 .collect::<Vec<_>>();
86
87 let mut language_model_list = IndexMap::default();
88 let mut recommended_models = HashSet::default();
89
90 let mut recommended = Vec::new();
91 for provider in &providers {
92 for model in provider.recommended_models(cx) {
93 recommended_models.insert(model.id());
94 recommended.push(Self::map_language_model_to_info(&model, &provider));
95 }
96 }
97 if !recommended.is_empty() {
98 language_model_list.insert(
99 acp_thread::AgentModelGroupName("Recommended".into()),
100 recommended,
101 );
102 }
103
104 let mut models = HashMap::default();
105 for provider in providers {
106 let mut provider_models = Vec::new();
107 for model in provider.provided_models(cx) {
108 let model_info = Self::map_language_model_to_info(&model, &provider);
109 let model_id = model_info.id.clone();
110 if !recommended_models.contains(&model.id()) {
111 provider_models.push(model_info);
112 }
113 models.insert(model_id, model);
114 }
115 if !provider_models.is_empty() {
116 language_model_list.insert(
117 acp_thread::AgentModelGroupName(provider.name().0.clone()),
118 provider_models,
119 );
120 }
121 }
122
123 self.models = models;
124 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
125 self.refresh_models_tx.send(()).ok();
126 }
127
128 fn watch(&self) -> watch::Receiver<()> {
129 self.refresh_models_rx.clone()
130 }
131
132 pub fn model_from_id(
133 &self,
134 model_id: &acp_thread::AgentModelId,
135 ) -> Option<Arc<dyn LanguageModel>> {
136 self.models.get(model_id).cloned()
137 }
138
139 fn map_language_model_to_info(
140 model: &Arc<dyn LanguageModel>,
141 provider: &Arc<dyn LanguageModelProvider>,
142 ) -> acp_thread::AgentModelInfo {
143 acp_thread::AgentModelInfo {
144 id: Self::model_id(model),
145 name: model.name().0,
146 icon: Some(provider.icon()),
147 }
148 }
149
150 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
151 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
152 }
153}
154
155pub struct NativeAgent {
156 /// Session ID -> Session mapping
157 sessions: HashMap<acp::SessionId, Session>,
158 /// Shared project context for all threads
159 project_context: Rc<RefCell<ProjectContext>>,
160 project_context_needs_refresh: watch::Sender<()>,
161 _maintain_project_context: Task<Result<()>>,
162 context_server_registry: Entity<ContextServerRegistry>,
163 /// Shared templates for all threads
164 templates: Arc<Templates>,
165 /// Cached model information
166 models: LanguageModels,
167 project: Entity<Project>,
168 prompt_store: Option<Entity<PromptStore>>,
169 fs: Arc<dyn Fs>,
170 _subscriptions: Vec<Subscription>,
171}
172
173impl NativeAgent {
174 pub async fn new(
175 project: Entity<Project>,
176 templates: Arc<Templates>,
177 prompt_store: Option<Entity<PromptStore>>,
178 fs: Arc<dyn Fs>,
179 cx: &mut AsyncApp,
180 ) -> Result<Entity<NativeAgent>> {
181 log::info!("Creating new NativeAgent");
182
183 let project_context = cx
184 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
185 .await;
186
187 cx.new(|cx| {
188 let mut subscriptions = vec![
189 cx.subscribe(&project, Self::handle_project_event),
190 cx.subscribe(
191 &LanguageModelRegistry::global(cx),
192 Self::handle_models_updated_event,
193 ),
194 ];
195 if let Some(prompt_store) = prompt_store.as_ref() {
196 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
197 }
198
199 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
200 watch::channel(());
201 Self {
202 sessions: HashMap::new(),
203 project_context: Rc::new(RefCell::new(project_context)),
204 project_context_needs_refresh: project_context_needs_refresh_tx,
205 _maintain_project_context: cx.spawn(async move |this, cx| {
206 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
207 }),
208 context_server_registry: cx.new(|cx| {
209 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
210 }),
211 templates,
212 models: LanguageModels::new(cx),
213 project,
214 prompt_store,
215 fs,
216 _subscriptions: subscriptions,
217 }
218 })
219 }
220
221 pub fn models(&self) -> &LanguageModels {
222 &self.models
223 }
224
225 async fn maintain_project_context(
226 this: WeakEntity<Self>,
227 mut needs_refresh: watch::Receiver<()>,
228 cx: &mut AsyncApp,
229 ) -> Result<()> {
230 while needs_refresh.changed().await.is_ok() {
231 let project_context = this
232 .update(cx, |this, cx| {
233 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
234 })?
235 .await;
236 this.update(cx, |this, _| this.project_context.replace(project_context))?;
237 }
238
239 Ok(())
240 }
241
242 fn build_project_context(
243 project: &Entity<Project>,
244 prompt_store: Option<&Entity<PromptStore>>,
245 cx: &mut App,
246 ) -> Task<ProjectContext> {
247 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
248 let worktree_tasks = worktrees
249 .into_iter()
250 .map(|worktree| {
251 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
252 })
253 .collect::<Vec<_>>();
254 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
255 prompt_store.read_with(cx, |prompt_store, cx| {
256 let prompts = prompt_store.default_prompt_metadata();
257 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
258 let contents = prompt_store.load(prompt_metadata.id, cx);
259 async move { (contents.await, prompt_metadata) }
260 });
261 cx.background_spawn(future::join_all(load_tasks))
262 })
263 } else {
264 Task::ready(vec![])
265 };
266
267 cx.spawn(async move |_cx| {
268 let (worktrees, default_user_rules) =
269 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
270
271 let worktrees = worktrees
272 .into_iter()
273 .map(|(worktree, _rules_error)| {
274 // TODO: show error message
275 // if let Some(rules_error) = rules_error {
276 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
277 // }
278 worktree
279 })
280 .collect::<Vec<_>>();
281
282 let default_user_rules = default_user_rules
283 .into_iter()
284 .flat_map(|(contents, prompt_metadata)| match contents {
285 Ok(contents) => Some(UserRulesContext {
286 uuid: match prompt_metadata.id {
287 PromptId::User { uuid } => uuid,
288 PromptId::EditWorkflow => return None,
289 },
290 title: prompt_metadata.title.map(|title| title.to_string()),
291 contents,
292 }),
293 Err(_err) => {
294 // TODO: show error message
295 // this.update(cx, |_, cx| {
296 // cx.emit(RulesLoadingError {
297 // message: format!("{err:?}").into(),
298 // });
299 // })
300 // .ok();
301 None
302 }
303 })
304 .collect::<Vec<_>>();
305
306 ProjectContext::new(worktrees, default_user_rules)
307 })
308 }
309
310 fn load_worktree_info_for_system_prompt(
311 worktree: Entity<Worktree>,
312 project: Entity<Project>,
313 cx: &mut App,
314 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
315 let tree = worktree.read(cx);
316 let root_name = tree.root_name().into();
317 let abs_path = tree.abs_path();
318
319 let mut context = WorktreeContext {
320 root_name,
321 abs_path,
322 rules_file: None,
323 };
324
325 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
326 let Some(rules_task) = rules_task else {
327 return Task::ready((context, None));
328 };
329
330 cx.spawn(async move |_| {
331 let (rules_file, rules_file_error) = match rules_task.await {
332 Ok(rules_file) => (Some(rules_file), None),
333 Err(err) => (
334 None,
335 Some(RulesLoadingError {
336 message: format!("{err}").into(),
337 }),
338 ),
339 };
340 context.rules_file = rules_file;
341 (context, rules_file_error)
342 })
343 }
344
345 fn load_worktree_rules_file(
346 worktree: Entity<Worktree>,
347 project: Entity<Project>,
348 cx: &mut App,
349 ) -> Option<Task<Result<RulesFileContext>>> {
350 let worktree = worktree.read(cx);
351 let worktree_id = worktree.id();
352 let selected_rules_file = RULES_FILE_NAMES
353 .into_iter()
354 .filter_map(|name| {
355 worktree
356 .entry_for_path(name)
357 .filter(|entry| entry.is_file())
358 .map(|entry| entry.path.clone())
359 })
360 .next();
361
362 // Note that Cline supports `.clinerules` being a directory, but that is not currently
363 // supported. This doesn't seem to occur often in GitHub repositories.
364 selected_rules_file.map(|path_in_worktree| {
365 let project_path = ProjectPath {
366 worktree_id,
367 path: path_in_worktree.clone(),
368 };
369 let buffer_task =
370 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
371 let rope_task = cx.spawn(async move |cx| {
372 buffer_task.await?.read_with(cx, |buffer, cx| {
373 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
374 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
375 })?
376 });
377 // Build a string from the rope on a background thread.
378 cx.background_spawn(async move {
379 let (project_entry_id, rope) = rope_task.await?;
380 anyhow::Ok(RulesFileContext {
381 path_in_worktree,
382 text: rope.to_string().trim().to_string(),
383 project_entry_id: project_entry_id.to_usize(),
384 })
385 })
386 })
387 }
388
389 fn handle_project_event(
390 &mut self,
391 _project: Entity<Project>,
392 event: &project::Event,
393 _cx: &mut Context<Self>,
394 ) {
395 match event {
396 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
397 self.project_context_needs_refresh.send(()).ok();
398 }
399 project::Event::WorktreeUpdatedEntries(_, items) => {
400 if items.iter().any(|(path, _, _)| {
401 RULES_FILE_NAMES
402 .iter()
403 .any(|name| path.as_ref() == Path::new(name))
404 }) {
405 self.project_context_needs_refresh.send(()).ok();
406 }
407 }
408 _ => {}
409 }
410 }
411
412 fn handle_prompts_updated_event(
413 &mut self,
414 _prompt_store: Entity<PromptStore>,
415 _event: &prompt_store::PromptsUpdatedEvent,
416 _cx: &mut Context<Self>,
417 ) {
418 self.project_context_needs_refresh.send(()).ok();
419 }
420
421 fn handle_models_updated_event(
422 &mut self,
423 _registry: Entity<LanguageModelRegistry>,
424 _event: &language_model::Event,
425 cx: &mut Context<Self>,
426 ) {
427 self.models.refresh_list(cx);
428 for session in self.sessions.values_mut() {
429 session.thread.update(cx, |thread, _| {
430 if let Some(model) = thread.model() {
431 let model_id = LanguageModels::model_id(model);
432 if let Some(model) = self.models.model_from_id(&model_id) {
433 thread.set_model(model.clone());
434 }
435 }
436 });
437 }
438 }
439}
440
441/// Wrapper struct that implements the AgentConnection trait
442#[derive(Clone)]
443pub struct NativeAgentConnection(pub Entity<NativeAgent>);
444
445impl NativeAgentConnection {
446 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
447 self.0
448 .read(cx)
449 .sessions
450 .get(session_id)
451 .map(|session| session.thread.clone())
452 }
453
454 fn run_turn(
455 &self,
456 session_id: acp::SessionId,
457 cx: &mut App,
458 f: impl 'static
459 + FnOnce(
460 Entity<Thread>,
461 &mut App,
462 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
463 ) -> Task<Result<acp::PromptResponse>> {
464 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
465 agent
466 .sessions
467 .get_mut(&session_id)
468 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
469 }) else {
470 return Task::ready(Err(anyhow!("Session not found")));
471 };
472 log::debug!("Found session for: {}", session_id);
473
474 let mut response_stream = match f(thread, cx) {
475 Ok(stream) => stream,
476 Err(err) => return Task::ready(Err(err)),
477 };
478 cx.spawn(async move |cx| {
479 // Handle response stream and forward to session.acp_thread
480 while let Some(result) = response_stream.next().await {
481 match result {
482 Ok(event) => {
483 log::trace!("Received completion event: {:?}", event);
484
485 match event {
486 AgentResponseEvent::Text(text) => {
487 acp_thread.update(cx, |thread, cx| {
488 thread.push_assistant_content_block(
489 acp::ContentBlock::Text(acp::TextContent {
490 text,
491 annotations: None,
492 }),
493 false,
494 cx,
495 )
496 })?;
497 }
498 AgentResponseEvent::Thinking(text) => {
499 acp_thread.update(cx, |thread, cx| {
500 thread.push_assistant_content_block(
501 acp::ContentBlock::Text(acp::TextContent {
502 text,
503 annotations: None,
504 }),
505 true,
506 cx,
507 )
508 })?;
509 }
510 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
511 tool_call,
512 options,
513 response,
514 }) => {
515 let recv = acp_thread.update(cx, |thread, cx| {
516 thread.request_tool_call_authorization(tool_call, options, cx)
517 })?;
518 cx.background_spawn(async move {
519 if let Some(recv) = recv.log_err()
520 && let Some(option) = recv
521 .await
522 .context("authorization sender was dropped")
523 .log_err()
524 {
525 response
526 .send(option)
527 .map(|_| anyhow!("authorization receiver was dropped"))
528 .log_err();
529 }
530 })
531 .detach();
532 }
533 AgentResponseEvent::ToolCall(tool_call) => {
534 acp_thread.update(cx, |thread, cx| {
535 thread.upsert_tool_call(tool_call, cx)
536 })??;
537 }
538 AgentResponseEvent::ToolCallUpdate(update) => {
539 acp_thread.update(cx, |thread, cx| {
540 thread.update_tool_call(update, cx)
541 })??;
542 }
543 AgentResponseEvent::Stop(stop_reason) => {
544 log::debug!("Assistant message complete: {:?}", stop_reason);
545 return Ok(acp::PromptResponse { stop_reason });
546 }
547 }
548 }
549 Err(e) => {
550 log::error!("Error in model response stream: {:?}", e);
551 return Err(e);
552 }
553 }
554 }
555
556 log::info!("Response stream completed");
557 anyhow::Ok(acp::PromptResponse {
558 stop_reason: acp::StopReason::EndTurn,
559 })
560 })
561 }
562}
563
564impl AgentModelSelector for NativeAgentConnection {
565 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
566 log::debug!("NativeAgentConnection::list_models called");
567 let list = self.0.read(cx).models.model_list.clone();
568 Task::ready(if list.is_empty() {
569 Err(anyhow::anyhow!("No models available"))
570 } else {
571 Ok(list)
572 })
573 }
574
575 fn select_model(
576 &self,
577 session_id: acp::SessionId,
578 model_id: acp_thread::AgentModelId,
579 cx: &mut App,
580 ) -> Task<Result<()>> {
581 log::info!("Setting model for session {}: {}", session_id, model_id);
582 let Some(thread) = self
583 .0
584 .read(cx)
585 .sessions
586 .get(&session_id)
587 .map(|session| session.thread.clone())
588 else {
589 return Task::ready(Err(anyhow!("Session not found")));
590 };
591
592 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
593 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
594 };
595
596 thread.update(cx, |thread, _cx| {
597 thread.set_model(model.clone());
598 });
599
600 update_settings_file::<AgentSettings>(
601 self.0.read(cx).fs.clone(),
602 cx,
603 move |settings, _cx| {
604 settings.set_model(model);
605 },
606 );
607
608 Task::ready(Ok(()))
609 }
610
611 fn selected_model(
612 &self,
613 session_id: &acp::SessionId,
614 cx: &mut App,
615 ) -> Task<Result<acp_thread::AgentModelInfo>> {
616 let session_id = session_id.clone();
617
618 let Some(thread) = self
619 .0
620 .read(cx)
621 .sessions
622 .get(&session_id)
623 .map(|session| session.thread.clone())
624 else {
625 return Task::ready(Err(anyhow!("Session not found")));
626 };
627 let Some(model) = thread.read(cx).model() else {
628 return Task::ready(Err(anyhow!("Model not found")));
629 };
630 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
631 else {
632 return Task::ready(Err(anyhow!("Provider not found")));
633 };
634 Task::ready(Ok(LanguageModels::map_language_model_to_info(
635 model, &provider,
636 )))
637 }
638
639 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
640 self.0.read(cx).models.watch()
641 }
642}
643
644impl acp_thread::AgentConnection for NativeAgentConnection {
645 fn new_thread(
646 self: Rc<Self>,
647 project: Entity<Project>,
648 cwd: &Path,
649 cx: &mut App,
650 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
651 let agent = self.0.clone();
652 log::info!("Creating new thread for project at: {:?}", cwd);
653
654 cx.spawn(async move |cx| {
655 log::debug!("Starting thread creation in async context");
656
657 // Generate session ID
658 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
659 log::info!("Created session with ID: {}", session_id);
660
661 // Create AcpThread
662 let acp_thread = cx.update(|cx| {
663 cx.new(|cx| {
664 acp_thread::AcpThread::new(
665 "agent2",
666 self.clone(),
667 project.clone(),
668 session_id.clone(),
669 cx,
670 )
671 })
672 })?;
673 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
674
675 // Create Thread
676 let thread = agent.update(
677 cx,
678 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
679 // Fetch default model from registry settings
680 let registry = LanguageModelRegistry::read_global(cx);
681
682 // Log available models for debugging
683 let available_count = registry.available_models(cx).count();
684 log::debug!("Total available models: {}", available_count);
685
686 let default_model = registry.default_model().and_then(|default_model| {
687 agent
688 .models
689 .model_from_id(&LanguageModels::model_id(&default_model.model))
690 });
691
692 let thread = cx.new(|cx| {
693 let mut thread = Thread::new(
694 project.clone(),
695 agent.project_context.clone(),
696 agent.context_server_registry.clone(),
697 action_log.clone(),
698 agent.templates.clone(),
699 default_model,
700 cx,
701 );
702 thread.add_tool(CopyPathTool::new(project.clone()));
703 thread.add_tool(CreateDirectoryTool::new(project.clone()));
704 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
705 thread.add_tool(DiagnosticsTool::new(project.clone()));
706 thread.add_tool(EditFileTool::new(cx.entity()));
707 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
708 thread.add_tool(FindPathTool::new(project.clone()));
709 thread.add_tool(GrepTool::new(project.clone()));
710 thread.add_tool(ListDirectoryTool::new(project.clone()));
711 thread.add_tool(MovePathTool::new(project.clone()));
712 thread.add_tool(NowTool);
713 thread.add_tool(OpenTool::new(project.clone()));
714 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
715 thread.add_tool(TerminalTool::new(project.clone(), cx));
716 thread.add_tool(ThinkingTool);
717 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
718 thread
719 });
720
721 Ok(thread)
722 },
723 )??;
724
725 // Store the session
726 agent.update(cx, |agent, cx| {
727 agent.sessions.insert(
728 session_id,
729 Session {
730 thread,
731 acp_thread: acp_thread.downgrade(),
732 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
733 this.sessions.remove(acp_thread.session_id());
734 }),
735 },
736 );
737 })?;
738
739 Ok(acp_thread)
740 })
741 }
742
743 fn auth_methods(&self) -> &[acp::AuthMethod] {
744 &[] // No auth for in-process
745 }
746
747 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
748 Task::ready(Ok(()))
749 }
750
751 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
752 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
753 }
754
755 fn prompt(
756 &self,
757 id: Option<acp_thread::UserMessageId>,
758 params: acp::PromptRequest,
759 cx: &mut App,
760 ) -> Task<Result<acp::PromptResponse>> {
761 let id = id.expect("UserMessageId is required");
762 let session_id = params.session_id.clone();
763 log::info!("Received prompt request for session: {}", session_id);
764 log::debug!("Prompt blocks count: {}", params.prompt.len());
765
766 self.run_turn(session_id, cx, |thread, cx| {
767 let content: Vec<UserMessageContent> = params
768 .prompt
769 .into_iter()
770 .map(Into::into)
771 .collect::<Vec<_>>();
772 log::info!("Converted prompt to message: {} chars", content.len());
773 log::debug!("Message id: {:?}", id);
774 log::debug!("Message content: {:?}", content);
775
776 thread.update(cx, |thread, cx| thread.send(id, content, cx))
777 })
778 }
779
780 fn resume(
781 &self,
782 session_id: &acp::SessionId,
783 _cx: &mut App,
784 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
785 Some(Rc::new(NativeAgentSessionResume {
786 connection: self.clone(),
787 session_id: session_id.clone(),
788 }) as _)
789 }
790
791 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
792 log::info!("Cancelling on session: {}", session_id);
793 self.0.update(cx, |agent, cx| {
794 if let Some(agent) = agent.sessions.get(session_id) {
795 agent.thread.update(cx, |thread, _cx| thread.cancel());
796 }
797 });
798 }
799
800 fn session_editor(
801 &self,
802 session_id: &agent_client_protocol::SessionId,
803 cx: &mut App,
804 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
805 self.0.update(cx, |agent, _cx| {
806 agent
807 .sessions
808 .get(session_id)
809 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
810 })
811 }
812
813 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
814 self
815 }
816}
817
818struct NativeAgentSessionEditor(Entity<Thread>);
819
820impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
821 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
822 Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
823 }
824}
825
826struct NativeAgentSessionResume {
827 connection: NativeAgentConnection,
828 session_id: acp::SessionId,
829}
830
831impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
832 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
833 self.connection
834 .run_turn(self.session_id.clone(), cx, |thread, cx| {
835 thread.update(cx, |thread, cx| thread.resume(cx))
836 })
837 }
838}
839
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Builds a `NativeAgent` for `project` with default templates and no prompt store.
    async fn create_agent(
        project: &Entity<Project>,
        fs: &Arc<FakeFs>,
        cx: &mut TestAppContext,
    ) -> Entity<NativeAgent> {
        NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap()
    }

    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = create_agent(&project, &fs, cx).await;

        // No worktrees yet, so the context is empty.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![]);
        });

        // Adding a worktree refreshes the context.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        let expected_without_rules = vec![WorktreeContext {
            root_name: "a".into(),
            abs_path: Path::new("/a").into(),
            rules_file: None,
        }];
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                expected_without_rules
            );
        });

        // Creating `/a/.rules` updates the project context as well.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry_id = worktree
                .read(cx)
                .entry_for_path(".rules")
                .unwrap()
                .id
                .to_usize();
            let expected_with_rules = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: Path::new(".rules").into(),
                    text: "".into(),
                    project_entry_id: rules_entry_id,
                }),
            }];
            assert_eq!(
                agent.project_context.borrow().worktrees,
                expected_with_rules
            );
        });
    }

    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(create_agent(&project, &fs, cx).await);

        let model_list = cx.update(|cx| connection.list_models(cx)).await.unwrap();
        let acp_thread::AgentModelList::Grouped(groups) = model_list else {
            panic!("Unexpected model group");
        };

        assert_eq!(
            groups,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let settings_path = paths::settings_file();
        fs.create_dir(settings_path.parent().unwrap())
            .await
            .unwrap();
        let initial_settings = json!({
            "agent": {
                "default_model": {
                    "provider": "foo",
                    "model": "bar"
                }
            }
        });
        fs.insert_file(settings_path, initial_settings.to_string().into_bytes())
            .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let agent = create_agent(&project, &fs, cx).await;
        let connection = NativeAgentConnection(agent.clone());

        // Open a session on the agent.
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Pick the fake model for that session.
        cx.update(|cx| {
            connection.select_model(session_id.clone(), AgentModelId("fake/fake".into()), cx)
        })
        .await
        .unwrap();

        // The session's thread now uses the selected model.
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        cx.run_until_parked();

        // The selection was written back to the settings file.
        let saved_text = fs.load(settings_path).await.unwrap();
        let saved: serde_json::Value = serde_json::from_str(&saved_text).unwrap();
        let default_model = &saved["agent"]["default_model"];
        assert_eq!(default_model["model"], json!("fake"));
        assert_eq!(default_model["provider"], json!("fake"));
    }

    /// Installs the globals (settings, language, fake model registry) the tests rely on.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}