1use crate::{
2 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
3 EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
4 OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
5 UserMessageContent, WebSearchTool, templates::Templates,
6};
7use acp_thread::AgentModelSelector;
8use action_log::ActionLog;
9use agent_client_protocol as acp;
10use agent_settings::AgentSettings;
11use anyhow::{Context as _, Result, anyhow};
12use collections::{HashSet, IndexMap};
13use fs::Fs;
14use futures::channel::mpsc;
15use futures::{StreamExt, future};
16use gpui::{
17 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
18};
19use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
20use project::{Project, ProjectItem, ProjectPath, Worktree};
21use prompt_store::{
22 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
23};
24use settings::update_settings_file;
25use std::any::Any;
26use std::collections::HashMap;
27use std::path::Path;
28use std::rc::Rc;
29use std::sync::Arc;
30use util::ResultExt;
31
/// Candidate rules-file paths, relative to a worktree root, in priority order.
///
/// When a worktree contains several of these, only the first one found (in this order) is
/// loaded into the project context. Updated entries at any of these paths schedule a
/// project-context refresh.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
43
44pub struct RulesLoadingError {
45 pub message: SharedString,
46}
47
/// Holds both the internal Thread and the AcpThread for a session.
///
/// Sessions live in `NativeAgent::sessions`, keyed by session ID, and are removed when the
/// corresponding `AcpThread` entity is released.
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication. Held as a weak handle so the
    /// session itself does not keep the `AcpThread` alive.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Release observer (created in `new_thread`) that removes this session from the map;
    /// stored here so its lifetime is tied to the session.
    _subscription: Subscription,
}
56
/// Cache of the language models offered by authenticated providers, plus a channel that
/// notifies listeners whenever the cache is rebuilt.
pub struct LanguageModels {
    /// Access language model by ID (formatted as `"<provider_id>/<model_id>"`)
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information: an optional "Recommended"
    /// group followed by one group per provider
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()`; signalled after every refresh
    refresh_models_rx: watch::Receiver<()>,
    /// Sender used by `refresh_list` to notify watchers
    refresh_models_tx: watch::Sender<()>,
}
65
66impl LanguageModels {
67 fn new(cx: &App) -> Self {
68 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
69 let mut this = Self {
70 models: HashMap::default(),
71 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
72 refresh_models_rx,
73 refresh_models_tx,
74 };
75 this.refresh_list(cx);
76 this
77 }
78
79 fn refresh_list(&mut self, cx: &App) {
80 let providers = LanguageModelRegistry::global(cx)
81 .read(cx)
82 .providers()
83 .into_iter()
84 .filter(|provider| provider.is_authenticated(cx))
85 .collect::<Vec<_>>();
86
87 let mut language_model_list = IndexMap::default();
88 let mut recommended_models = HashSet::default();
89
90 let mut recommended = Vec::new();
91 for provider in &providers {
92 for model in provider.recommended_models(cx) {
93 recommended_models.insert(model.id());
94 recommended.push(Self::map_language_model_to_info(&model, provider));
95 }
96 }
97 if !recommended.is_empty() {
98 language_model_list.insert(
99 acp_thread::AgentModelGroupName("Recommended".into()),
100 recommended,
101 );
102 }
103
104 let mut models = HashMap::default();
105 for provider in providers {
106 let mut provider_models = Vec::new();
107 for model in provider.provided_models(cx) {
108 let model_info = Self::map_language_model_to_info(&model, &provider);
109 let model_id = model_info.id.clone();
110 if !recommended_models.contains(&model.id()) {
111 provider_models.push(model_info);
112 }
113 models.insert(model_id, model);
114 }
115 if !provider_models.is_empty() {
116 language_model_list.insert(
117 acp_thread::AgentModelGroupName(provider.name().0.clone()),
118 provider_models,
119 );
120 }
121 }
122
123 self.models = models;
124 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
125 self.refresh_models_tx.send(()).ok();
126 }
127
128 fn watch(&self) -> watch::Receiver<()> {
129 self.refresh_models_rx.clone()
130 }
131
132 pub fn model_from_id(
133 &self,
134 model_id: &acp_thread::AgentModelId,
135 ) -> Option<Arc<dyn LanguageModel>> {
136 self.models.get(model_id).cloned()
137 }
138
139 fn map_language_model_to_info(
140 model: &Arc<dyn LanguageModel>,
141 provider: &Arc<dyn LanguageModelProvider>,
142 ) -> acp_thread::AgentModelInfo {
143 acp_thread::AgentModelInfo {
144 id: Self::model_id(model),
145 name: model.name().0,
146 icon: Some(provider.icon()),
147 }
148 }
149
150 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
151 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
152 }
153}
154
/// The in-process ("native") agent. It owns every active session plus the state shared by
/// all of their threads: project context, context servers, templates, and the model cache.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled whenever the project context should be rebuilt (worktree added/removed,
    /// rules-file entries updated, prompt store updated)
    project_context_needs_refresh: watch::Sender<()>,
    /// Task that rebuilds the project context each time a refresh is requested
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry handed to every new thread
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project whose visible worktrees feed the project context
    project: Entity<Project>,
    /// Source of default user rules, if available
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle, used to persist model selection to the settings file
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, language-model registry, and prompt-store events
    _subscriptions: Vec<Subscription>,
}
172
173impl NativeAgent {
174 pub async fn new(
175 project: Entity<Project>,
176 templates: Arc<Templates>,
177 prompt_store: Option<Entity<PromptStore>>,
178 fs: Arc<dyn Fs>,
179 cx: &mut AsyncApp,
180 ) -> Result<Entity<NativeAgent>> {
181 log::info!("Creating new NativeAgent");
182
183 let project_context = cx
184 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
185 .await;
186
187 cx.new(|cx| {
188 let mut subscriptions = vec![
189 cx.subscribe(&project, Self::handle_project_event),
190 cx.subscribe(
191 &LanguageModelRegistry::global(cx),
192 Self::handle_models_updated_event,
193 ),
194 ];
195 if let Some(prompt_store) = prompt_store.as_ref() {
196 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
197 }
198
199 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
200 watch::channel(());
201 Self {
202 sessions: HashMap::new(),
203 project_context: cx.new(|_| project_context),
204 project_context_needs_refresh: project_context_needs_refresh_tx,
205 _maintain_project_context: cx.spawn(async move |this, cx| {
206 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
207 }),
208 context_server_registry: cx.new(|cx| {
209 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
210 }),
211 templates,
212 models: LanguageModels::new(cx),
213 project,
214 prompt_store,
215 fs,
216 _subscriptions: subscriptions,
217 }
218 })
219 }
220
221 pub fn models(&self) -> &LanguageModels {
222 &self.models
223 }
224
225 async fn maintain_project_context(
226 this: WeakEntity<Self>,
227 mut needs_refresh: watch::Receiver<()>,
228 cx: &mut AsyncApp,
229 ) -> Result<()> {
230 while needs_refresh.changed().await.is_ok() {
231 let project_context = this
232 .update(cx, |this, cx| {
233 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
234 })?
235 .await;
236 this.update(cx, |this, cx| {
237 this.project_context = cx.new(|_| project_context);
238 })?;
239 }
240
241 Ok(())
242 }
243
244 fn build_project_context(
245 project: &Entity<Project>,
246 prompt_store: Option<&Entity<PromptStore>>,
247 cx: &mut App,
248 ) -> Task<ProjectContext> {
249 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
250 let worktree_tasks = worktrees
251 .into_iter()
252 .map(|worktree| {
253 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
254 })
255 .collect::<Vec<_>>();
256 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
257 prompt_store.read_with(cx, |prompt_store, cx| {
258 let prompts = prompt_store.default_prompt_metadata();
259 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
260 let contents = prompt_store.load(prompt_metadata.id, cx);
261 async move { (contents.await, prompt_metadata) }
262 });
263 cx.background_spawn(future::join_all(load_tasks))
264 })
265 } else {
266 Task::ready(vec![])
267 };
268
269 cx.spawn(async move |_cx| {
270 let (worktrees, default_user_rules) =
271 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
272
273 let worktrees = worktrees
274 .into_iter()
275 .map(|(worktree, _rules_error)| {
276 // TODO: show error message
277 // if let Some(rules_error) = rules_error {
278 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
279 // }
280 worktree
281 })
282 .collect::<Vec<_>>();
283
284 let default_user_rules = default_user_rules
285 .into_iter()
286 .flat_map(|(contents, prompt_metadata)| match contents {
287 Ok(contents) => Some(UserRulesContext {
288 uuid: match prompt_metadata.id {
289 PromptId::User { uuid } => uuid,
290 PromptId::EditWorkflow => return None,
291 },
292 title: prompt_metadata.title.map(|title| title.to_string()),
293 contents,
294 }),
295 Err(_err) => {
296 // TODO: show error message
297 // this.update(cx, |_, cx| {
298 // cx.emit(RulesLoadingError {
299 // message: format!("{err:?}").into(),
300 // });
301 // })
302 // .ok();
303 None
304 }
305 })
306 .collect::<Vec<_>>();
307
308 ProjectContext::new(worktrees, default_user_rules)
309 })
310 }
311
312 fn load_worktree_info_for_system_prompt(
313 worktree: Entity<Worktree>,
314 project: Entity<Project>,
315 cx: &mut App,
316 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
317 let tree = worktree.read(cx);
318 let root_name = tree.root_name().into();
319 let abs_path = tree.abs_path();
320
321 let mut context = WorktreeContext {
322 root_name,
323 abs_path,
324 rules_file: None,
325 };
326
327 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
328 let Some(rules_task) = rules_task else {
329 return Task::ready((context, None));
330 };
331
332 cx.spawn(async move |_| {
333 let (rules_file, rules_file_error) = match rules_task.await {
334 Ok(rules_file) => (Some(rules_file), None),
335 Err(err) => (
336 None,
337 Some(RulesLoadingError {
338 message: format!("{err}").into(),
339 }),
340 ),
341 };
342 context.rules_file = rules_file;
343 (context, rules_file_error)
344 })
345 }
346
347 fn load_worktree_rules_file(
348 worktree: Entity<Worktree>,
349 project: Entity<Project>,
350 cx: &mut App,
351 ) -> Option<Task<Result<RulesFileContext>>> {
352 let worktree = worktree.read(cx);
353 let worktree_id = worktree.id();
354 let selected_rules_file = RULES_FILE_NAMES
355 .into_iter()
356 .filter_map(|name| {
357 worktree
358 .entry_for_path(name)
359 .filter(|entry| entry.is_file())
360 .map(|entry| entry.path.clone())
361 })
362 .next();
363
364 // Note that Cline supports `.clinerules` being a directory, but that is not currently
365 // supported. This doesn't seem to occur often in GitHub repositories.
366 selected_rules_file.map(|path_in_worktree| {
367 let project_path = ProjectPath {
368 worktree_id,
369 path: path_in_worktree.clone(),
370 };
371 let buffer_task =
372 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
373 let rope_task = cx.spawn(async move |cx| {
374 buffer_task.await?.read_with(cx, |buffer, cx| {
375 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
376 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
377 })?
378 });
379 // Build a string from the rope on a background thread.
380 cx.background_spawn(async move {
381 let (project_entry_id, rope) = rope_task.await?;
382 anyhow::Ok(RulesFileContext {
383 path_in_worktree,
384 text: rope.to_string().trim().to_string(),
385 project_entry_id: project_entry_id.to_usize(),
386 })
387 })
388 })
389 }
390
391 fn handle_project_event(
392 &mut self,
393 _project: Entity<Project>,
394 event: &project::Event,
395 _cx: &mut Context<Self>,
396 ) {
397 match event {
398 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
399 self.project_context_needs_refresh.send(()).ok();
400 }
401 project::Event::WorktreeUpdatedEntries(_, items) => {
402 if items.iter().any(|(path, _, _)| {
403 RULES_FILE_NAMES
404 .iter()
405 .any(|name| path.as_ref() == Path::new(name))
406 }) {
407 self.project_context_needs_refresh.send(()).ok();
408 }
409 }
410 _ => {}
411 }
412 }
413
414 fn handle_prompts_updated_event(
415 &mut self,
416 _prompt_store: Entity<PromptStore>,
417 _event: &prompt_store::PromptsUpdatedEvent,
418 _cx: &mut Context<Self>,
419 ) {
420 self.project_context_needs_refresh.send(()).ok();
421 }
422
423 fn handle_models_updated_event(
424 &mut self,
425 _registry: Entity<LanguageModelRegistry>,
426 _event: &language_model::Event,
427 cx: &mut Context<Self>,
428 ) {
429 self.models.refresh_list(cx);
430
431 let registry = LanguageModelRegistry::read_global(cx);
432 let default_model = registry.default_model().map(|m| m.model.clone());
433 let summarization_model = registry.thread_summary_model().map(|m| m.model.clone());
434
435 for session in self.sessions.values_mut() {
436 session.thread.update(cx, |thread, cx| {
437 if thread.model().is_none()
438 && let Some(model) = default_model.clone()
439 {
440 thread.set_model(model, cx);
441 cx.notify();
442 }
443 thread.set_summarization_model(summarization_model.clone(), cx);
444 });
445 }
446 }
447}
448
/// Wrapper struct that implements the AgentConnection trait (and `AgentModelSelector`) on
/// top of a shared [`NativeAgent`] entity. Cloning clones only the entity handle, so all
/// clones drive the same agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
452
453impl NativeAgentConnection {
454 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
455 self.0
456 .read(cx)
457 .sessions
458 .get(session_id)
459 .map(|session| session.thread.clone())
460 }
461
462 fn run_turn(
463 &self,
464 session_id: acp::SessionId,
465 cx: &mut App,
466 f: impl 'static
467 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
468 ) -> Task<Result<acp::PromptResponse>> {
469 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
470 agent
471 .sessions
472 .get_mut(&session_id)
473 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
474 }) else {
475 return Task::ready(Err(anyhow!("Session not found")));
476 };
477 log::debug!("Found session for: {}", session_id);
478
479 let mut response_stream = match f(thread, cx) {
480 Ok(stream) => stream,
481 Err(err) => return Task::ready(Err(err)),
482 };
483 cx.spawn(async move |cx| {
484 // Handle response stream and forward to session.acp_thread
485 while let Some(result) = response_stream.next().await {
486 match result {
487 Ok(event) => {
488 log::trace!("Received completion event: {:?}", event);
489
490 match event {
491 ThreadEvent::UserMessage(message) => {
492 acp_thread.update(cx, |thread, cx| {
493 for content in message.content {
494 thread.push_user_content_block(
495 Some(message.id.clone()),
496 content.into(),
497 cx,
498 );
499 }
500 })?;
501 }
502 ThreadEvent::AgentText(text) => {
503 acp_thread.update(cx, |thread, cx| {
504 thread.push_assistant_content_block(
505 acp::ContentBlock::Text(acp::TextContent {
506 text,
507 annotations: None,
508 }),
509 false,
510 cx,
511 )
512 })?;
513 }
514 ThreadEvent::AgentThinking(text) => {
515 acp_thread.update(cx, |thread, cx| {
516 thread.push_assistant_content_block(
517 acp::ContentBlock::Text(acp::TextContent {
518 text,
519 annotations: None,
520 }),
521 true,
522 cx,
523 )
524 })?;
525 }
526 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
527 tool_call,
528 options,
529 response,
530 }) => {
531 let recv = acp_thread.update(cx, |thread, cx| {
532 thread.request_tool_call_authorization(tool_call, options, cx)
533 })?;
534 cx.background_spawn(async move {
535 if let Some(recv) = recv.log_err()
536 && let Some(option) = recv
537 .await
538 .context("authorization sender was dropped")
539 .log_err()
540 {
541 response
542 .send(option)
543 .map(|_| anyhow!("authorization receiver was dropped"))
544 .log_err();
545 }
546 })
547 .detach();
548 }
549 ThreadEvent::ToolCall(tool_call) => {
550 acp_thread.update(cx, |thread, cx| {
551 thread.upsert_tool_call(tool_call, cx)
552 })??;
553 }
554 ThreadEvent::ToolCallUpdate(update) => {
555 acp_thread.update(cx, |thread, cx| {
556 thread.update_tool_call(update, cx)
557 })??;
558 }
559 ThreadEvent::TitleUpdate(title) => {
560 acp_thread
561 .update(cx, |thread, cx| thread.update_title(title, cx))??;
562 }
563 ThreadEvent::Retry(status) => {
564 acp_thread.update(cx, |thread, cx| {
565 thread.update_retry_status(status, cx)
566 })?;
567 }
568 ThreadEvent::Stop(stop_reason) => {
569 log::debug!("Assistant message complete: {:?}", stop_reason);
570 return Ok(acp::PromptResponse { stop_reason });
571 }
572 }
573 }
574 Err(e) => {
575 log::error!("Error in model response stream: {:?}", e);
576 return Err(e);
577 }
578 }
579 }
580
581 log::info!("Response stream completed");
582 anyhow::Ok(acp::PromptResponse {
583 stop_reason: acp::StopReason::EndTurn,
584 })
585 })
586 }
587}
588
589impl AgentModelSelector for NativeAgentConnection {
590 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
591 log::debug!("NativeAgentConnection::list_models called");
592 let list = self.0.read(cx).models.model_list.clone();
593 Task::ready(if list.is_empty() {
594 Err(anyhow::anyhow!("No models available"))
595 } else {
596 Ok(list)
597 })
598 }
599
600 fn select_model(
601 &self,
602 session_id: acp::SessionId,
603 model_id: acp_thread::AgentModelId,
604 cx: &mut App,
605 ) -> Task<Result<()>> {
606 log::info!("Setting model for session {}: {}", session_id, model_id);
607 let Some(thread) = self
608 .0
609 .read(cx)
610 .sessions
611 .get(&session_id)
612 .map(|session| session.thread.clone())
613 else {
614 return Task::ready(Err(anyhow!("Session not found")));
615 };
616
617 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
618 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
619 };
620
621 thread.update(cx, |thread, cx| {
622 thread.set_model(model.clone(), cx);
623 });
624
625 update_settings_file::<AgentSettings>(
626 self.0.read(cx).fs.clone(),
627 cx,
628 move |settings, _cx| {
629 settings.set_model(model);
630 },
631 );
632
633 Task::ready(Ok(()))
634 }
635
636 fn selected_model(
637 &self,
638 session_id: &acp::SessionId,
639 cx: &mut App,
640 ) -> Task<Result<acp_thread::AgentModelInfo>> {
641 let session_id = session_id.clone();
642
643 let Some(thread) = self
644 .0
645 .read(cx)
646 .sessions
647 .get(&session_id)
648 .map(|session| session.thread.clone())
649 else {
650 return Task::ready(Err(anyhow!("Session not found")));
651 };
652 let Some(model) = thread.read(cx).model() else {
653 return Task::ready(Err(anyhow!("Model not found")));
654 };
655 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
656 else {
657 return Task::ready(Err(anyhow!("Provider not found")));
658 };
659 Task::ready(Ok(LanguageModels::map_language_model_to_info(
660 model, &provider,
661 )))
662 }
663
664 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
665 self.0.read(cx).models.watch()
666 }
667}
668
669impl acp_thread::AgentConnection for NativeAgentConnection {
670 fn new_thread(
671 self: Rc<Self>,
672 project: Entity<Project>,
673 cwd: &Path,
674 cx: &mut App,
675 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
676 let agent = self.0.clone();
677 log::info!("Creating new thread for project at: {:?}", cwd);
678
679 cx.spawn(async move |cx| {
680 log::debug!("Starting thread creation in async context");
681
682 let action_log = cx.new(|_cx| ActionLog::new(project.clone()))?;
683 // Create Thread
684 let thread = agent.update(
685 cx,
686 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
687 // Fetch default model from registry settings
688 let registry = LanguageModelRegistry::read_global(cx);
689 let language_registry = project.read(cx).languages().clone();
690
691 // Log available models for debugging
692 let available_count = registry.available_models(cx).count();
693 log::debug!("Total available models: {}", available_count);
694
695 let default_model = registry.default_model().and_then(|default_model| {
696 agent
697 .models
698 .model_from_id(&LanguageModels::model_id(&default_model.model))
699 });
700 let summarization_model = registry.thread_summary_model().map(|c| c.model);
701
702 let thread = cx.new(|cx| {
703 let mut thread = Thread::new(
704 project.clone(),
705 agent.project_context.clone(),
706 agent.context_server_registry.clone(),
707 action_log.clone(),
708 agent.templates.clone(),
709 default_model,
710 summarization_model,
711 cx,
712 );
713 thread.add_tool(CopyPathTool::new(project.clone()));
714 thread.add_tool(CreateDirectoryTool::new(project.clone()));
715 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
716 thread.add_tool(DiagnosticsTool::new(project.clone()));
717 thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
718 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
719 thread.add_tool(FindPathTool::new(project.clone()));
720 thread.add_tool(GrepTool::new(project.clone()));
721 thread.add_tool(ListDirectoryTool::new(project.clone()));
722 thread.add_tool(MovePathTool::new(project.clone()));
723 thread.add_tool(NowTool);
724 thread.add_tool(OpenTool::new(project.clone()));
725 thread.add_tool(ReadFileTool::new(project.clone(), action_log.clone()));
726 thread.add_tool(TerminalTool::new(project.clone(), cx));
727 thread.add_tool(ThinkingTool);
728 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
729 thread
730 });
731
732 Ok(thread)
733 },
734 )??;
735
736 let session_id = thread.read_with(cx, |thread, _| thread.id().clone())?;
737 log::info!("Created session with ID: {}", session_id);
738 // Create AcpThread
739 let acp_thread = cx.update(|cx| {
740 cx.new(|_cx| {
741 acp_thread::AcpThread::new(
742 "agent2",
743 self.clone(),
744 project.clone(),
745 action_log.clone(),
746 session_id.clone(),
747 )
748 })
749 })?;
750
751 // Store the session
752 agent.update(cx, |agent, cx| {
753 agent.sessions.insert(
754 session_id,
755 Session {
756 thread,
757 acp_thread: acp_thread.downgrade(),
758 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
759 this.sessions.remove(acp_thread.session_id());
760 }),
761 },
762 );
763 })?;
764
765 Ok(acp_thread)
766 })
767 }
768
769 fn auth_methods(&self) -> &[acp::AuthMethod] {
770 &[] // No auth for in-process
771 }
772
773 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
774 Task::ready(Ok(()))
775 }
776
777 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
778 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
779 }
780
781 fn prompt(
782 &self,
783 id: Option<acp_thread::UserMessageId>,
784 params: acp::PromptRequest,
785 cx: &mut App,
786 ) -> Task<Result<acp::PromptResponse>> {
787 let id = id.expect("UserMessageId is required");
788 let session_id = params.session_id.clone();
789 log::info!("Received prompt request for session: {}", session_id);
790 log::debug!("Prompt blocks count: {}", params.prompt.len());
791
792 self.run_turn(session_id, cx, |thread, cx| {
793 let content: Vec<UserMessageContent> = params
794 .prompt
795 .into_iter()
796 .map(Into::into)
797 .collect::<Vec<_>>();
798 log::info!("Converted prompt to message: {} chars", content.len());
799 log::debug!("Message id: {:?}", id);
800 log::debug!("Message content: {:?}", content);
801
802 thread.update(cx, |thread, cx| thread.send(id, content, cx))
803 })
804 }
805
806 fn resume(
807 &self,
808 session_id: &acp::SessionId,
809 _cx: &mut App,
810 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
811 Some(Rc::new(NativeAgentSessionResume {
812 connection: self.clone(),
813 session_id: session_id.clone(),
814 }) as _)
815 }
816
817 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
818 log::info!("Cancelling on session: {}", session_id);
819 self.0.update(cx, |agent, cx| {
820 if let Some(agent) = agent.sessions.get(session_id) {
821 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
822 }
823 });
824 }
825
826 fn session_editor(
827 &self,
828 session_id: &agent_client_protocol::SessionId,
829 cx: &mut App,
830 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
831 self.0.update(cx, |agent, _cx| {
832 agent
833 .sessions
834 .get(session_id)
835 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
836 })
837 }
838
839 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
840 self
841 }
842}
843
844struct NativeAgentSessionEditor(Entity<Thread>);
845
846impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
847 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
848 Task::ready(
849 self.0
850 .update(cx, |thread, cx| thread.truncate(message_id, cx)),
851 )
852 }
853}
854
855struct NativeAgentSessionResume {
856 connection: NativeAgentConnection,
857 session_id: acp::SessionId,
858}
859
860impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
861 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
862 self.connection
863 .run_turn(self.session_id.clone(), cx, |thread, cx| {
864 thread.update(cx, |thread, cx| thread.resume(cx))
865 })
866 }
867}
868
/// Tests for the native agent: project-context maintenance, model listing, and persisting
/// model selection to the settings file.
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// The project context starts empty, picks up a newly added worktree, and then picks up
    /// a `.rules` file created inside that worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        // Let the refresh triggered by `WorktreeAdded` run to completion.
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// With the test registry, `list_models` yields a single "Fake" provider group holding
    /// the `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model sets it on the session's thread and rewrites
    /// `agent.default_model` in the settings file (replacing the pre-seeded `foo`/`bar`).
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write scheduled by `select_model` finish.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Installs a test settings store and initializes the project, agent, and language
    /// settings. It also installs a test `LanguageModelRegistry`, which presumably registers
    /// the fake provider the tests above rely on.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}