1use crate::{
2 AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
3 DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
4 MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
5 ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
6};
7use acp_thread::AgentModelSelector;
8use agent_client_protocol as acp;
9use agent_settings::AgentSettings;
10use anyhow::{Context as _, Result, anyhow};
11use collections::{HashSet, IndexMap};
12use fs::Fs;
13use futures::channel::mpsc;
14use futures::{StreamExt, future};
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
17};
18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
19use project::{Project, ProjectItem, ProjectPath, Worktree};
20use prompt_store::{
21 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
22};
23use settings::update_settings_file;
24use std::any::Any;
25use std::cell::RefCell;
26use std::collections::HashMap;
27use std::path::Path;
28use std::rc::Rc;
29use std::sync::Arc;
30use util::ResultExt;
31
/// File names, relative to a worktree root, that are treated as agent rules files.
///
/// Order matters: when several of these exist in one worktree, the first name in this
/// list that refers to a regular file is the one that gets loaded.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
43
/// Describes a failure to load a worktree's rules file.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the formatted underlying error).
    pub message: SharedString,
}
47
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held weakly: the agent does not keep the ACP thread alive.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Keeps alive the release observer that removes this session from
    /// `NativeAgent::sessions` once the ACP thread is released.
    _subscription: Subscription,
}
56
/// Catalog of language models offered by the native agent, built from the
/// authenticated providers of the global `LanguageModelRegistry`.
pub struct LanguageModels {
    /// Access language model by ID (IDs have the form `provider_id/model_id`)
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()`; observes every refresh of the catalog.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of each `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
}
65
66impl LanguageModels {
67 fn new(cx: &App) -> Self {
68 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
69 let mut this = Self {
70 models: HashMap::default(),
71 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
72 refresh_models_rx,
73 refresh_models_tx,
74 };
75 this.refresh_list(cx);
76 this
77 }
78
79 fn refresh_list(&mut self, cx: &App) {
80 let providers = LanguageModelRegistry::global(cx)
81 .read(cx)
82 .providers()
83 .into_iter()
84 .filter(|provider| provider.is_authenticated(cx))
85 .collect::<Vec<_>>();
86
87 let mut language_model_list = IndexMap::default();
88 let mut recommended_models = HashSet::default();
89
90 let mut recommended = Vec::new();
91 for provider in &providers {
92 for model in provider.recommended_models(cx) {
93 recommended_models.insert(model.id());
94 recommended.push(Self::map_language_model_to_info(&model, provider));
95 }
96 }
97 if !recommended.is_empty() {
98 language_model_list.insert(
99 acp_thread::AgentModelGroupName("Recommended".into()),
100 recommended,
101 );
102 }
103
104 let mut models = HashMap::default();
105 for provider in providers {
106 let mut provider_models = Vec::new();
107 for model in provider.provided_models(cx) {
108 let model_info = Self::map_language_model_to_info(&model, &provider);
109 let model_id = model_info.id.clone();
110 if !recommended_models.contains(&model.id()) {
111 provider_models.push(model_info);
112 }
113 models.insert(model_id, model);
114 }
115 if !provider_models.is_empty() {
116 language_model_list.insert(
117 acp_thread::AgentModelGroupName(provider.name().0.clone()),
118 provider_models,
119 );
120 }
121 }
122
123 self.models = models;
124 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
125 self.refresh_models_tx.send(()).ok();
126 }
127
128 fn watch(&self) -> watch::Receiver<()> {
129 self.refresh_models_rx.clone()
130 }
131
132 pub fn model_from_id(
133 &self,
134 model_id: &acp_thread::AgentModelId,
135 ) -> Option<Arc<dyn LanguageModel>> {
136 self.models.get(model_id).cloned()
137 }
138
139 fn map_language_model_to_info(
140 model: &Arc<dyn LanguageModel>,
141 provider: &Arc<dyn LanguageModelProvider>,
142 ) -> acp_thread::AgentModelInfo {
143 acp_thread::AgentModelInfo {
144 id: Self::model_id(model),
145 name: model.name().0,
146 icon: Some(provider.icon()),
147 }
148 }
149
150 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
151 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
152 }
153}
154
/// The in-process agent: owns the sessions, the shared project context, the prompt
/// templates and the model catalog used by every thread it creates.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Rc<RefCell<ProjectContext>>,
    /// Signalled whenever `project_context` should be rebuilt (worktree changes,
    /// rules-file updates, prompt-store updates).
    project_context_needs_refresh: watch::Sender<()>,
    /// Task that rebuilds `project_context` each time a refresh is requested.
    _maintain_project_context: Task<Result<()>>,
    /// Context server registry passed to every new thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project whose visible worktrees feed `project_context`.
    project: Entity<Project>,
    /// Optional source of default user rules for `project_context`.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle; used to persist the selected model to the settings file.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, model-registry and prompt-store events.
    _subscriptions: Vec<Subscription>,
}
172
173impl NativeAgent {
174 pub async fn new(
175 project: Entity<Project>,
176 templates: Arc<Templates>,
177 prompt_store: Option<Entity<PromptStore>>,
178 fs: Arc<dyn Fs>,
179 cx: &mut AsyncApp,
180 ) -> Result<Entity<NativeAgent>> {
181 log::info!("Creating new NativeAgent");
182
183 let project_context = cx
184 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
185 .await;
186
187 cx.new(|cx| {
188 let mut subscriptions = vec![
189 cx.subscribe(&project, Self::handle_project_event),
190 cx.subscribe(
191 &LanguageModelRegistry::global(cx),
192 Self::handle_models_updated_event,
193 ),
194 ];
195 if let Some(prompt_store) = prompt_store.as_ref() {
196 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
197 }
198
199 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
200 watch::channel(());
201 Self {
202 sessions: HashMap::new(),
203 project_context: Rc::new(RefCell::new(project_context)),
204 project_context_needs_refresh: project_context_needs_refresh_tx,
205 _maintain_project_context: cx.spawn(async move |this, cx| {
206 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
207 }),
208 context_server_registry: cx.new(|cx| {
209 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
210 }),
211 templates,
212 models: LanguageModels::new(cx),
213 project,
214 prompt_store,
215 fs,
216 _subscriptions: subscriptions,
217 }
218 })
219 }
220
221 pub fn models(&self) -> &LanguageModels {
222 &self.models
223 }
224
225 async fn maintain_project_context(
226 this: WeakEntity<Self>,
227 mut needs_refresh: watch::Receiver<()>,
228 cx: &mut AsyncApp,
229 ) -> Result<()> {
230 while needs_refresh.changed().await.is_ok() {
231 let project_context = this
232 .update(cx, |this, cx| {
233 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
234 })?
235 .await;
236 this.update(cx, |this, _| this.project_context.replace(project_context))?;
237 }
238
239 Ok(())
240 }
241
242 fn build_project_context(
243 project: &Entity<Project>,
244 prompt_store: Option<&Entity<PromptStore>>,
245 cx: &mut App,
246 ) -> Task<ProjectContext> {
247 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
248 let worktree_tasks = worktrees
249 .into_iter()
250 .map(|worktree| {
251 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
252 })
253 .collect::<Vec<_>>();
254 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
255 prompt_store.read_with(cx, |prompt_store, cx| {
256 let prompts = prompt_store.default_prompt_metadata();
257 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
258 let contents = prompt_store.load(prompt_metadata.id, cx);
259 async move { (contents.await, prompt_metadata) }
260 });
261 cx.background_spawn(future::join_all(load_tasks))
262 })
263 } else {
264 Task::ready(vec![])
265 };
266
267 cx.spawn(async move |_cx| {
268 let (worktrees, default_user_rules) =
269 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
270
271 let worktrees = worktrees
272 .into_iter()
273 .map(|(worktree, _rules_error)| {
274 // TODO: show error message
275 // if let Some(rules_error) = rules_error {
276 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
277 // }
278 worktree
279 })
280 .collect::<Vec<_>>();
281
282 let default_user_rules = default_user_rules
283 .into_iter()
284 .flat_map(|(contents, prompt_metadata)| match contents {
285 Ok(contents) => Some(UserRulesContext {
286 uuid: match prompt_metadata.id {
287 PromptId::User { uuid } => uuid,
288 PromptId::EditWorkflow => return None,
289 },
290 title: prompt_metadata.title.map(|title| title.to_string()),
291 contents,
292 }),
293 Err(_err) => {
294 // TODO: show error message
295 // this.update(cx, |_, cx| {
296 // cx.emit(RulesLoadingError {
297 // message: format!("{err:?}").into(),
298 // });
299 // })
300 // .ok();
301 None
302 }
303 })
304 .collect::<Vec<_>>();
305
306 ProjectContext::new(worktrees, default_user_rules)
307 })
308 }
309
310 fn load_worktree_info_for_system_prompt(
311 worktree: Entity<Worktree>,
312 project: Entity<Project>,
313 cx: &mut App,
314 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
315 let tree = worktree.read(cx);
316 let root_name = tree.root_name().into();
317 let abs_path = tree.abs_path();
318
319 let mut context = WorktreeContext {
320 root_name,
321 abs_path,
322 rules_file: None,
323 };
324
325 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
326 let Some(rules_task) = rules_task else {
327 return Task::ready((context, None));
328 };
329
330 cx.spawn(async move |_| {
331 let (rules_file, rules_file_error) = match rules_task.await {
332 Ok(rules_file) => (Some(rules_file), None),
333 Err(err) => (
334 None,
335 Some(RulesLoadingError {
336 message: format!("{err}").into(),
337 }),
338 ),
339 };
340 context.rules_file = rules_file;
341 (context, rules_file_error)
342 })
343 }
344
345 fn load_worktree_rules_file(
346 worktree: Entity<Worktree>,
347 project: Entity<Project>,
348 cx: &mut App,
349 ) -> Option<Task<Result<RulesFileContext>>> {
350 let worktree = worktree.read(cx);
351 let worktree_id = worktree.id();
352 let selected_rules_file = RULES_FILE_NAMES
353 .into_iter()
354 .filter_map(|name| {
355 worktree
356 .entry_for_path(name)
357 .filter(|entry| entry.is_file())
358 .map(|entry| entry.path.clone())
359 })
360 .next();
361
362 // Note that Cline supports `.clinerules` being a directory, but that is not currently
363 // supported. This doesn't seem to occur often in GitHub repositories.
364 selected_rules_file.map(|path_in_worktree| {
365 let project_path = ProjectPath {
366 worktree_id,
367 path: path_in_worktree.clone(),
368 };
369 let buffer_task =
370 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
371 let rope_task = cx.spawn(async move |cx| {
372 buffer_task.await?.read_with(cx, |buffer, cx| {
373 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
374 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
375 })?
376 });
377 // Build a string from the rope on a background thread.
378 cx.background_spawn(async move {
379 let (project_entry_id, rope) = rope_task.await?;
380 anyhow::Ok(RulesFileContext {
381 path_in_worktree,
382 text: rope.to_string().trim().to_string(),
383 project_entry_id: project_entry_id.to_usize(),
384 })
385 })
386 })
387 }
388
389 fn handle_project_event(
390 &mut self,
391 _project: Entity<Project>,
392 event: &project::Event,
393 _cx: &mut Context<Self>,
394 ) {
395 match event {
396 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
397 self.project_context_needs_refresh.send(()).ok();
398 }
399 project::Event::WorktreeUpdatedEntries(_, items) => {
400 if items.iter().any(|(path, _, _)| {
401 RULES_FILE_NAMES
402 .iter()
403 .any(|name| path.as_ref() == Path::new(name))
404 }) {
405 self.project_context_needs_refresh.send(()).ok();
406 }
407 }
408 _ => {}
409 }
410 }
411
412 fn handle_prompts_updated_event(
413 &mut self,
414 _prompt_store: Entity<PromptStore>,
415 _event: &prompt_store::PromptsUpdatedEvent,
416 _cx: &mut Context<Self>,
417 ) {
418 self.project_context_needs_refresh.send(()).ok();
419 }
420
421 fn handle_models_updated_event(
422 &mut self,
423 _registry: Entity<LanguageModelRegistry>,
424 _event: &language_model::Event,
425 cx: &mut Context<Self>,
426 ) {
427 self.models.refresh_list(cx);
428
429 let default_model = LanguageModelRegistry::read_global(cx)
430 .default_model()
431 .map(|m| m.model.clone());
432
433 for session in self.sessions.values_mut() {
434 session.thread.update(cx, |thread, cx| {
435 if thread.model().is_none()
436 && let Some(model) = default_model.clone()
437 {
438 thread.set_model(model);
439 cx.notify();
440 }
441 });
442 }
443 }
444}
445
/// Wrapper struct that implements the AgentConnection trait
///
/// Cloning clones the wrapped `Entity<NativeAgent>` handle, so every clone drives the
/// same agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
449
450impl NativeAgentConnection {
451 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
452 self.0
453 .read(cx)
454 .sessions
455 .get(session_id)
456 .map(|session| session.thread.clone())
457 }
458
459 fn run_turn(
460 &self,
461 session_id: acp::SessionId,
462 cx: &mut App,
463 f: impl 'static
464 + FnOnce(
465 Entity<Thread>,
466 &mut App,
467 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
468 ) -> Task<Result<acp::PromptResponse>> {
469 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
470 agent
471 .sessions
472 .get_mut(&session_id)
473 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
474 }) else {
475 return Task::ready(Err(anyhow!("Session not found")));
476 };
477 log::debug!("Found session for: {}", session_id);
478
479 let mut response_stream = match f(thread, cx) {
480 Ok(stream) => stream,
481 Err(err) => return Task::ready(Err(err)),
482 };
483 cx.spawn(async move |cx| {
484 // Handle response stream and forward to session.acp_thread
485 while let Some(result) = response_stream.next().await {
486 match result {
487 Ok(event) => {
488 log::trace!("Received completion event: {:?}", event);
489
490 match event {
491 AgentResponseEvent::Text(text) => {
492 acp_thread.update(cx, |thread, cx| {
493 thread.push_assistant_content_block(
494 acp::ContentBlock::Text(acp::TextContent {
495 text,
496 annotations: None,
497 }),
498 false,
499 cx,
500 )
501 })?;
502 }
503 AgentResponseEvent::Thinking(text) => {
504 acp_thread.update(cx, |thread, cx| {
505 thread.push_assistant_content_block(
506 acp::ContentBlock::Text(acp::TextContent {
507 text,
508 annotations: None,
509 }),
510 true,
511 cx,
512 )
513 })?;
514 }
515 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
516 tool_call,
517 options,
518 response,
519 }) => {
520 let recv = acp_thread.update(cx, |thread, cx| {
521 thread.request_tool_call_authorization(tool_call, options, cx)
522 })?;
523 cx.background_spawn(async move {
524 if let Some(recv) = recv.log_err()
525 && let Some(option) = recv
526 .await
527 .context("authorization sender was dropped")
528 .log_err()
529 {
530 response
531 .send(option)
532 .map(|_| anyhow!("authorization receiver was dropped"))
533 .log_err();
534 }
535 })
536 .detach();
537 }
538 AgentResponseEvent::ToolCall(tool_call) => {
539 acp_thread.update(cx, |thread, cx| {
540 thread.upsert_tool_call(tool_call, cx)
541 })??;
542 }
543 AgentResponseEvent::ToolCallUpdate(update) => {
544 acp_thread.update(cx, |thread, cx| {
545 thread.update_tool_call(update, cx)
546 })??;
547 }
548 AgentResponseEvent::Stop(stop_reason) => {
549 log::debug!("Assistant message complete: {:?}", stop_reason);
550 return Ok(acp::PromptResponse { stop_reason });
551 }
552 }
553 }
554 Err(e) => {
555 log::error!("Error in model response stream: {:?}", e);
556 return Err(e);
557 }
558 }
559 }
560
561 log::info!("Response stream completed");
562 anyhow::Ok(acp::PromptResponse {
563 stop_reason: acp::StopReason::EndTurn,
564 })
565 })
566 }
567}
568
569impl AgentModelSelector for NativeAgentConnection {
570 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
571 log::debug!("NativeAgentConnection::list_models called");
572 let list = self.0.read(cx).models.model_list.clone();
573 Task::ready(if list.is_empty() {
574 Err(anyhow::anyhow!("No models available"))
575 } else {
576 Ok(list)
577 })
578 }
579
580 fn select_model(
581 &self,
582 session_id: acp::SessionId,
583 model_id: acp_thread::AgentModelId,
584 cx: &mut App,
585 ) -> Task<Result<()>> {
586 log::info!("Setting model for session {}: {}", session_id, model_id);
587 let Some(thread) = self
588 .0
589 .read(cx)
590 .sessions
591 .get(&session_id)
592 .map(|session| session.thread.clone())
593 else {
594 return Task::ready(Err(anyhow!("Session not found")));
595 };
596
597 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
598 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
599 };
600
601 thread.update(cx, |thread, _cx| {
602 thread.set_model(model.clone());
603 });
604
605 update_settings_file::<AgentSettings>(
606 self.0.read(cx).fs.clone(),
607 cx,
608 move |settings, _cx| {
609 settings.set_model(model);
610 },
611 );
612
613 Task::ready(Ok(()))
614 }
615
616 fn selected_model(
617 &self,
618 session_id: &acp::SessionId,
619 cx: &mut App,
620 ) -> Task<Result<acp_thread::AgentModelInfo>> {
621 let session_id = session_id.clone();
622
623 let Some(thread) = self
624 .0
625 .read(cx)
626 .sessions
627 .get(&session_id)
628 .map(|session| session.thread.clone())
629 else {
630 return Task::ready(Err(anyhow!("Session not found")));
631 };
632 let Some(model) = thread.read(cx).model() else {
633 return Task::ready(Err(anyhow!("Model not found")));
634 };
635 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
636 else {
637 return Task::ready(Err(anyhow!("Provider not found")));
638 };
639 Task::ready(Ok(LanguageModels::map_language_model_to_info(
640 model, &provider,
641 )))
642 }
643
644 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
645 self.0.read(cx).models.watch()
646 }
647}
648
649impl acp_thread::AgentConnection for NativeAgentConnection {
650 fn new_thread(
651 self: Rc<Self>,
652 project: Entity<Project>,
653 cwd: &Path,
654 cx: &mut App,
655 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
656 let agent = self.0.clone();
657 log::info!("Creating new thread for project at: {:?}", cwd);
658
659 cx.spawn(async move |cx| {
660 log::debug!("Starting thread creation in async context");
661
662 // Generate session ID
663 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
664 log::info!("Created session with ID: {}", session_id);
665
666 // Create AcpThread
667 let acp_thread = cx.update(|cx| {
668 cx.new(|cx| {
669 acp_thread::AcpThread::new(
670 "agent2",
671 self.clone(),
672 project.clone(),
673 session_id.clone(),
674 cx,
675 )
676 })
677 })?;
678 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
679
680 // Create Thread
681 let thread = agent.update(
682 cx,
683 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
684 // Fetch default model from registry settings
685 let registry = LanguageModelRegistry::read_global(cx);
686
687 // Log available models for debugging
688 let available_count = registry.available_models(cx).count();
689 log::debug!("Total available models: {}", available_count);
690
691 let default_model = registry.default_model().and_then(|default_model| {
692 agent
693 .models
694 .model_from_id(&LanguageModels::model_id(&default_model.model))
695 });
696
697 let thread = cx.new(|cx| {
698 let mut thread = Thread::new(
699 project.clone(),
700 agent.project_context.clone(),
701 agent.context_server_registry.clone(),
702 action_log.clone(),
703 agent.templates.clone(),
704 default_model,
705 cx,
706 );
707 thread.add_tool(CopyPathTool::new(project.clone()));
708 thread.add_tool(CreateDirectoryTool::new(project.clone()));
709 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
710 thread.add_tool(DiagnosticsTool::new(project.clone()));
711 thread.add_tool(EditFileTool::new(cx.entity()));
712 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
713 thread.add_tool(FindPathTool::new(project.clone()));
714 thread.add_tool(GrepTool::new(project.clone()));
715 thread.add_tool(ListDirectoryTool::new(project.clone()));
716 thread.add_tool(MovePathTool::new(project.clone()));
717 thread.add_tool(NowTool);
718 thread.add_tool(OpenTool::new(project.clone()));
719 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
720 thread.add_tool(TerminalTool::new(project.clone(), cx));
721 thread.add_tool(ThinkingTool);
722 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
723 thread
724 });
725
726 Ok(thread)
727 },
728 )??;
729
730 // Store the session
731 agent.update(cx, |agent, cx| {
732 agent.sessions.insert(
733 session_id,
734 Session {
735 thread,
736 acp_thread: acp_thread.downgrade(),
737 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
738 this.sessions.remove(acp_thread.session_id());
739 }),
740 },
741 );
742 })?;
743
744 Ok(acp_thread)
745 })
746 }
747
748 fn auth_methods(&self) -> &[acp::AuthMethod] {
749 &[] // No auth for in-process
750 }
751
752 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
753 Task::ready(Ok(()))
754 }
755
756 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
757 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
758 }
759
760 fn prompt(
761 &self,
762 id: Option<acp_thread::UserMessageId>,
763 params: acp::PromptRequest,
764 cx: &mut App,
765 ) -> Task<Result<acp::PromptResponse>> {
766 let id = id.expect("UserMessageId is required");
767 let session_id = params.session_id.clone();
768 log::info!("Received prompt request for session: {}", session_id);
769 log::debug!("Prompt blocks count: {}", params.prompt.len());
770
771 self.run_turn(session_id, cx, |thread, cx| {
772 let content: Vec<UserMessageContent> = params
773 .prompt
774 .into_iter()
775 .map(Into::into)
776 .collect::<Vec<_>>();
777 log::info!("Converted prompt to message: {} chars", content.len());
778 log::debug!("Message id: {:?}", id);
779 log::debug!("Message content: {:?}", content);
780
781 thread.update(cx, |thread, cx| thread.send(id, content, cx))
782 })
783 }
784
785 fn resume(
786 &self,
787 session_id: &acp::SessionId,
788 _cx: &mut App,
789 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
790 Some(Rc::new(NativeAgentSessionResume {
791 connection: self.clone(),
792 session_id: session_id.clone(),
793 }) as _)
794 }
795
796 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
797 log::info!("Cancelling on session: {}", session_id);
798 self.0.update(cx, |agent, cx| {
799 if let Some(agent) = agent.sessions.get(session_id) {
800 agent.thread.update(cx, |thread, _cx| thread.cancel());
801 }
802 });
803 }
804
805 fn session_editor(
806 &self,
807 session_id: &agent_client_protocol::SessionId,
808 cx: &mut App,
809 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
810 self.0.update(cx, |agent, _cx| {
811 agent
812 .sessions
813 .get(session_id)
814 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
815 })
816 }
817
818 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
819 self
820 }
821}
822
823struct NativeAgentSessionEditor(Entity<Thread>);
824
825impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
826 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
827 Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
828 }
829}
830
831struct NativeAgentSessionResume {
832 connection: NativeAgentConnection,
833 session_id: acp::SessionId,
834}
835
836impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
837 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
838 self.connection
839 .run_turn(self.session_id.clone(), cx, |thread, cx| {
840 thread.update(cx, |thread, cx| thread.resume(cx))
841 })
842 }
843}
844
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// The shared project context follows worktree additions and the creation of a
    /// `.rules` file at a worktree root.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project starts with no worktrees, so the context has none either.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a worktree for `/a` refreshes the context; there is no rules file yet.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// `list_models` returns a single "Fake" group holding the fake model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model updates the session's thread and rewrites the
    /// `agent.default_model` entry in the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Start from a settings file whose default model differs from the one selected below.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write issued by `select_model` complete.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Installs the globals these tests rely on: a test settings store, project, agent
    /// and language settings, and the test language-model registry.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}