1use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
2use crate::{
3 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
4 EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
5 OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
6 UserMessageContent, WebSearchTool, templates::Templates,
7};
8use crate::{ThreadsDatabase, generate_session_id};
9use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
10use agent_client_protocol as acp;
11use agent_settings::AgentSettings;
12use anyhow::{Context as _, Result, anyhow};
13use collections::{HashSet, IndexMap};
14use fs::Fs;
15use futures::channel::mpsc;
16use futures::{StreamExt, future};
17use gpui::{
18 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
19};
20use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel};
21use project::{Project, ProjectItem, ProjectPath, Worktree};
22use prompt_store::{
23 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
24};
25use settings::update_settings_file;
26use std::any::Any;
27use std::cell::RefCell;
28use std::collections::HashMap;
29use std::path::Path;
30use std::rc::Rc;
31use std::sync::Arc;
32use std::time::Duration;
33use util::ResultExt;
34
/// Candidate rules files looked up at the root of each visible worktree, in priority order.
///
/// The first name that exists as a regular file in a worktree becomes that worktree's rules
/// file (see `load_worktree_rules_file`). Changes to any of these paths trigger a rebuild of
/// the project context (see `handle_project_event`).
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];

/// Delay between the most recent change to a thread and persisting it to the database.
const SAVE_THREAD_DEBOUNCE: Duration = Duration::from_millis(500);
48
/// Error produced when a worktree's rules file was found but could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the formatted underlying error).
    pub message: SharedString,
}
52
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    /// Held weakly: the session is removed from the agent when this thread is released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Pending debounced save of `thread`; replaced every time the thread changes.
    save_task: Task<Result<()>>,
    /// Keeps this session's observers alive (ACP-thread release and `thread` changes).
    _subscriptions: Vec<Subscription>,
}
62
/// Cached view of the models offered by authenticated language model providers.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out (cloned) by `watch()` so observers learn about list refreshes.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled each time the model list is rebuilt.
    refresh_models_tx: watch::Sender<()>,
}
71
72impl LanguageModels {
73 fn new(cx: &App) -> Self {
74 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
75 let mut this = Self {
76 models: HashMap::default(),
77 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
78 refresh_models_rx,
79 refresh_models_tx,
80 };
81 this.refresh_list(cx);
82 this
83 }
84
85 fn refresh_list(&mut self, cx: &App) {
86 let providers = LanguageModelRegistry::global(cx)
87 .read(cx)
88 .providers()
89 .into_iter()
90 .filter(|provider| provider.is_authenticated(cx))
91 .collect::<Vec<_>>();
92
93 let mut language_model_list = IndexMap::default();
94 let mut recommended_models = HashSet::default();
95
96 let mut recommended = Vec::new();
97 for provider in &providers {
98 for model in provider.recommended_models(cx) {
99 recommended_models.insert(model.id());
100 recommended.push(Self::map_language_model_to_info(&model, &provider));
101 }
102 }
103 if !recommended.is_empty() {
104 language_model_list.insert(
105 acp_thread::AgentModelGroupName("Recommended".into()),
106 recommended,
107 );
108 }
109
110 let mut models = HashMap::default();
111 for provider in providers {
112 let mut provider_models = Vec::new();
113 for model in provider.provided_models(cx) {
114 let model_info = Self::map_language_model_to_info(&model, &provider);
115 let model_id = model_info.id.clone();
116 if !recommended_models.contains(&model.id()) {
117 provider_models.push(model_info);
118 }
119 models.insert(model_id, model);
120 }
121 if !provider_models.is_empty() {
122 language_model_list.insert(
123 acp_thread::AgentModelGroupName(provider.name().0.clone()),
124 provider_models,
125 );
126 }
127 }
128
129 self.models = models;
130 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
131 self.refresh_models_tx.send(()).ok();
132 }
133
134 fn watch(&self) -> watch::Receiver<()> {
135 self.refresh_models_rx.clone()
136 }
137
138 pub fn model_from_id(
139 &self,
140 model_id: &acp_thread::AgentModelId,
141 ) -> Option<Arc<dyn LanguageModel>> {
142 self.models.get(model_id).cloned()
143 }
144
145 fn map_language_model_to_info(
146 model: &Arc<dyn LanguageModel>,
147 provider: &Arc<dyn LanguageModelProvider>,
148 ) -> acp_thread::AgentModelInfo {
149 acp_thread::AgentModelInfo {
150 id: Self::model_id(model),
151 name: model.name().0,
152 icon: Some(provider.icon()),
153 }
154 }
155
156 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
157 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
158 }
159}
160
/// In-process agent that owns the native [`Thread`]s backing ACP sessions, along with the
/// state they share (project context, templates, models, thread history).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Rc<RefCell<ProjectContext>>,
    /// Signalled when the project context must be rebuilt
    /// (worktree added/removed, rules file changed, prompts updated).
    project_context_needs_refresh: watch::Sender<()>,
    /// Task that rebuilds `project_context` each time a refresh is signalled.
    _maintain_project_context: Task<Result<()>>,
    /// Context servers made available to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    project: Entity<Project>,
    /// Optional source of the user's default rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Persistent storage for saved threads.
    thread_database: Arc<ThreadsDatabase>,
    /// Latest thread history listing; `None` until the first load completes.
    history: watch::Sender<Option<Vec<AcpThreadMetadata>>>,
    /// Most recent history reload; replaced by each new reload.
    load_history: Task<()>,
    fs: Arc<dyn Fs>,
    _subscriptions: Vec<Subscription>,
}
181
182impl NativeAgent {
183 pub async fn new(
184 project: Entity<Project>,
185 templates: Arc<Templates>,
186 prompt_store: Option<Entity<PromptStore>>,
187 fs: Arc<dyn Fs>,
188 cx: &mut AsyncApp,
189 ) -> Result<Entity<NativeAgent>> {
190 log::info!("Creating new NativeAgent");
191
192 let project_context = cx
193 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
194 .await;
195
196 let thread_database = cx
197 .update(|cx| ThreadsDatabase::connect(cx))?
198 .await
199 .map_err(|e| anyhow!(e))?;
200
201 cx.new(|cx| {
202 let mut subscriptions = vec![
203 cx.subscribe(&project, Self::handle_project_event),
204 cx.subscribe(
205 &LanguageModelRegistry::global(cx),
206 Self::handle_models_updated_event,
207 ),
208 ];
209 if let Some(prompt_store) = prompt_store.as_ref() {
210 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
211 }
212
213 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
214 watch::channel(());
215 let mut this = Self {
216 sessions: HashMap::new(),
217 project_context: Rc::new(RefCell::new(project_context)),
218 project_context_needs_refresh: project_context_needs_refresh_tx,
219 _maintain_project_context: cx.spawn(async move |this, cx| {
220 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
221 }),
222 context_server_registry: cx.new(|cx| {
223 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
224 }),
225 thread_database,
226 templates,
227 models: LanguageModels::new(cx),
228 project,
229 prompt_store,
230 fs,
231 history: watch::channel(None).0,
232 load_history: Task::ready(()),
233 _subscriptions: subscriptions,
234 };
235 this.reload_history(cx);
236 this
237 })
238 }
239
240 pub fn insert_session(
241 &mut self,
242 thread: Entity<Thread>,
243 acp_thread: Entity<AcpThread>,
244 cx: &mut Context<Self>,
245 ) {
246 let id = thread.read(cx).id().clone();
247 let weak_thread = acp_thread.downgrade();
248 self.sessions.insert(
249 id,
250 Session {
251 thread: thread.clone(),
252 acp_thread: weak_thread.clone(),
253 save_task: Task::ready(Ok(())),
254 _subscriptions: vec![
255 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
256 this.sessions.remove(acp_thread.session_id());
257 }),
258 cx.observe(&thread, move |this, thread, cx| {
259 if let Some(response_stream) =
260 thread.update(cx, |thread, cx| thread.generate_title_if_needed(cx))
261 {
262 NativeAgentConnection::handle_thread_events(
263 response_stream,
264 weak_thread.clone(),
265 cx,
266 )
267 .detach_and_log_err(cx);
268 }
269 this.save_thread(thread.clone(), cx)
270 }),
271 ],
272 },
273 );
274 }
275
276 fn save_thread(&mut self, thread_handle: Entity<Thread>, cx: &mut Context<Self>) {
277 let thread = thread_handle.read(cx);
278 let id = thread.id().clone();
279 let Some(session) = self.sessions.get_mut(&id) else {
280 return;
281 };
282
283 let thread = thread_handle.downgrade();
284 let thread_database = self.thread_database.clone();
285 session.save_task = cx.spawn(async move |this, cx| {
286 cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
287
288 let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await;
289 thread_database.save_thread(id, db_thread).await?;
290 this.update(cx, |this, cx| this.reload_history(cx))?;
291 Ok(())
292 });
293 }
294
295 fn reload_history(&mut self, cx: &mut Context<Self>) {
296 let thread_database = self.thread_database.clone();
297 self.load_history = cx.spawn(async move |this, cx| {
298 let results = cx
299 .background_spawn(async move {
300 let results = thread_database.list_threads().await?;
301 anyhow::Ok(
302 results
303 .into_iter()
304 .map(|thread| AcpThreadMetadata {
305 agent: NATIVE_AGENT_SERVER_NAME.clone(),
306 id: thread.id.into(),
307 title: thread.title,
308 updated_at: thread.updated_at,
309 })
310 .collect(),
311 )
312 })
313 .await;
314 if let Some(results) = results.log_err() {
315 this.update(cx, |this, _| this.history.send(Some(results)))
316 .ok();
317 }
318 });
319 }
320
321 pub fn models(&self) -> &LanguageModels {
322 &self.models
323 }
324
325 async fn maintain_project_context(
326 this: WeakEntity<Self>,
327 mut needs_refresh: watch::Receiver<()>,
328 cx: &mut AsyncApp,
329 ) -> Result<()> {
330 while needs_refresh.changed().await.is_ok() {
331 let project_context = this
332 .update(cx, |this, cx| {
333 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
334 })?
335 .await;
336 this.update(cx, |this, _| this.project_context.replace(project_context))?;
337 }
338
339 Ok(())
340 }
341
342 fn build_project_context(
343 project: &Entity<Project>,
344 prompt_store: Option<&Entity<PromptStore>>,
345 cx: &mut App,
346 ) -> Task<ProjectContext> {
347 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
348 let worktree_tasks = worktrees
349 .into_iter()
350 .map(|worktree| {
351 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
352 })
353 .collect::<Vec<_>>();
354 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
355 prompt_store.read_with(cx, |prompt_store, cx| {
356 let prompts = prompt_store.default_prompt_metadata();
357 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
358 let contents = prompt_store.load(prompt_metadata.id, cx);
359 async move { (contents.await, prompt_metadata) }
360 });
361 cx.background_spawn(future::join_all(load_tasks))
362 })
363 } else {
364 Task::ready(vec![])
365 };
366
367 cx.spawn(async move |_cx| {
368 let (worktrees, default_user_rules) =
369 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
370
371 let worktrees = worktrees
372 .into_iter()
373 .map(|(worktree, _rules_error)| {
374 // TODO: show error message
375 // if let Some(rules_error) = rules_error {
376 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
377 // }
378 worktree
379 })
380 .collect::<Vec<_>>();
381
382 let default_user_rules = default_user_rules
383 .into_iter()
384 .flat_map(|(contents, prompt_metadata)| match contents {
385 Ok(contents) => Some(UserRulesContext {
386 uuid: match prompt_metadata.id {
387 PromptId::User { uuid } => uuid,
388 PromptId::EditWorkflow => return None,
389 },
390 title: prompt_metadata.title.map(|title| title.to_string()),
391 contents,
392 }),
393 Err(_err) => {
394 // TODO: show error message
395 // this.update(cx, |_, cx| {
396 // cx.emit(RulesLoadingError {
397 // message: format!("{err:?}").into(),
398 // });
399 // })
400 // .ok();
401 None
402 }
403 })
404 .collect::<Vec<_>>();
405
406 ProjectContext::new(worktrees, default_user_rules)
407 })
408 }
409
410 fn load_worktree_info_for_system_prompt(
411 worktree: Entity<Worktree>,
412 project: Entity<Project>,
413 cx: &mut App,
414 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
415 let tree = worktree.read(cx);
416 let root_name = tree.root_name().into();
417 let abs_path = tree.abs_path();
418
419 let mut context = WorktreeContext {
420 root_name,
421 abs_path,
422 rules_file: None,
423 };
424
425 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
426 let Some(rules_task) = rules_task else {
427 return Task::ready((context, None));
428 };
429
430 cx.spawn(async move |_| {
431 let (rules_file, rules_file_error) = match rules_task.await {
432 Ok(rules_file) => (Some(rules_file), None),
433 Err(err) => (
434 None,
435 Some(RulesLoadingError {
436 message: format!("{err}").into(),
437 }),
438 ),
439 };
440 context.rules_file = rules_file;
441 (context, rules_file_error)
442 })
443 }
444
445 fn load_worktree_rules_file(
446 worktree: Entity<Worktree>,
447 project: Entity<Project>,
448 cx: &mut App,
449 ) -> Option<Task<Result<RulesFileContext>>> {
450 let worktree = worktree.read(cx);
451 let worktree_id = worktree.id();
452 let selected_rules_file = RULES_FILE_NAMES
453 .into_iter()
454 .filter_map(|name| {
455 worktree
456 .entry_for_path(name)
457 .filter(|entry| entry.is_file())
458 .map(|entry| entry.path.clone())
459 })
460 .next();
461
462 // Note that Cline supports `.clinerules` being a directory, but that is not currently
463 // supported. This doesn't seem to occur often in GitHub repositories.
464 selected_rules_file.map(|path_in_worktree| {
465 let project_path = ProjectPath {
466 worktree_id,
467 path: path_in_worktree.clone(),
468 };
469 let buffer_task =
470 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
471 let rope_task = cx.spawn(async move |cx| {
472 buffer_task.await?.read_with(cx, |buffer, cx| {
473 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
474 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
475 })?
476 });
477 // Build a string from the rope on a background thread.
478 cx.background_spawn(async move {
479 let (project_entry_id, rope) = rope_task.await?;
480 anyhow::Ok(RulesFileContext {
481 path_in_worktree,
482 text: rope.to_string().trim().to_string(),
483 project_entry_id: project_entry_id.to_usize(),
484 })
485 })
486 })
487 }
488
489 fn handle_project_event(
490 &mut self,
491 _project: Entity<Project>,
492 event: &project::Event,
493 _cx: &mut Context<Self>,
494 ) {
495 match event {
496 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
497 self.project_context_needs_refresh.send(()).ok();
498 }
499 project::Event::WorktreeUpdatedEntries(_, items) => {
500 if items.iter().any(|(path, _, _)| {
501 RULES_FILE_NAMES
502 .iter()
503 .any(|name| path.as_ref() == Path::new(name))
504 }) {
505 self.project_context_needs_refresh.send(()).ok();
506 }
507 }
508 _ => {}
509 }
510 }
511
512 fn handle_prompts_updated_event(
513 &mut self,
514 _prompt_store: Entity<PromptStore>,
515 _event: &prompt_store::PromptsUpdatedEvent,
516 _cx: &mut Context<Self>,
517 ) {
518 self.project_context_needs_refresh.send(()).ok();
519 }
520
521 fn handle_models_updated_event(
522 &mut self,
523 registry: Entity<LanguageModelRegistry>,
524 _event: &language_model::Event,
525 cx: &mut Context<Self>,
526 ) {
527 self.models.refresh_list(cx);
528
529 let default_model = LanguageModelRegistry::read_global(cx)
530 .default_model()
531 .map(|m| m.model.clone());
532
533 for session in self.sessions.values_mut() {
534 session.thread.update(cx, |thread, cx| {
535 if thread.model().is_none()
536 && let Some(model) = default_model.clone()
537 {
538 thread.set_model(model, cx);
539 cx.notify();
540 }
541 let summarization_model = registry
542 .read(cx)
543 .thread_summary_model()
544 .map(|model| model.model.clone());
545 thread.set_summarization_model(summarization_model, cx);
546 });
547 }
548 }
549}
550
/// Wrapper struct that implements the AgentConnection trait
/// (and `AgentModelSelector`) on top of a shared [`NativeAgent`] entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
554
555impl NativeAgentConnection {
556 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
557 self.0
558 .read(cx)
559 .sessions
560 .get(session_id)
561 .map(|session| session.thread.clone())
562 }
563
564 fn run_turn(
565 &self,
566 session_id: acp::SessionId,
567 cx: &mut App,
568 f: impl 'static
569 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
570 ) -> Task<Result<acp::PromptResponse>> {
571 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
572 agent
573 .sessions
574 .get_mut(&session_id)
575 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
576 }) else {
577 return Task::ready(Err(anyhow!("Session not found")));
578 };
579 log::debug!("Found session for: {}", session_id);
580
581 let response_stream = match f(thread, cx) {
582 Ok(stream) => stream,
583 Err(err) => return Task::ready(Err(err)),
584 };
585 Self::handle_thread_events(response_stream, acp_thread, cx)
586 }
587
588 fn handle_thread_events(
589 mut response_stream: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
590 acp_thread: WeakEntity<AcpThread>,
591 cx: &mut App,
592 ) -> Task<Result<acp::PromptResponse>> {
593 cx.spawn(async move |cx| {
594 // Handle response stream and forward to session.acp_thread
595 while let Some(result) = response_stream.next().await {
596 match result {
597 Ok(event) => {
598 log::trace!("Received completion event: {:?}", event);
599
600 match event {
601 ThreadEvent::UserMessage(message) => {
602 acp_thread.update(cx, |thread, cx| {
603 for content in message.content {
604 thread.push_user_content_block(
605 Some(message.id.clone()),
606 content.into(),
607 cx,
608 );
609 }
610 })?;
611 }
612 ThreadEvent::AgentText(text) => {
613 acp_thread.update(cx, |thread, cx| {
614 thread.push_assistant_content_block(
615 acp::ContentBlock::Text(acp::TextContent {
616 text,
617 annotations: None,
618 }),
619 false,
620 cx,
621 )
622 })?;
623 }
624 ThreadEvent::AgentThinking(text) => {
625 acp_thread.update(cx, |thread, cx| {
626 thread.push_assistant_content_block(
627 acp::ContentBlock::Text(acp::TextContent {
628 text,
629 annotations: None,
630 }),
631 true,
632 cx,
633 )
634 })?;
635 }
636 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
637 tool_call,
638 options,
639 response,
640 }) => {
641 let recv = acp_thread.update(cx, |thread, cx| {
642 thread.request_tool_call_authorization(tool_call, options, cx)
643 })?;
644 cx.background_spawn(async move {
645 if let Some(recv) = recv.log_err()
646 && let Some(option) = recv
647 .await
648 .context("authorization sender was dropped")
649 .log_err()
650 {
651 response
652 .send(option)
653 .map(|_| anyhow!("authorization receiver was dropped"))
654 .log_err();
655 }
656 })
657 .detach();
658 }
659 ThreadEvent::ToolCall(tool_call) => {
660 acp_thread.update(cx, |thread, cx| {
661 thread.upsert_tool_call(tool_call, cx)
662 })??;
663 }
664 ThreadEvent::ToolCallUpdate(update) => {
665 acp_thread.update(cx, |thread, cx| {
666 thread.update_tool_call(update, cx)
667 })??;
668 }
669 ThreadEvent::TitleUpdate(title) => {
670 dbg!("updating title");
671 acp_thread
672 .update(cx, |thread, cx| thread.update_title(title, cx))??;
673 }
674 ThreadEvent::Stop(stop_reason) => {
675 log::debug!("Assistant message complete: {:?}", stop_reason);
676 return Ok(acp::PromptResponse { stop_reason });
677 }
678 }
679 }
680 Err(e) => {
681 log::error!("Error in model response stream: {:?}", e);
682 return Err(e);
683 }
684 }
685 }
686
687 log::info!("Response stream completed");
688 anyhow::Ok(acp::PromptResponse {
689 stop_reason: acp::StopReason::EndTurn,
690 })
691 })
692 }
693
694 fn register_tools(
695 thread: &mut Thread,
696 project: Entity<Project>,
697 action_log: Entity<action_log::ActionLog>,
698 cx: &mut Context<Thread>,
699 ) {
700 let language_registry = project.read(cx).languages().clone();
701 thread.add_tool(CopyPathTool::new(project.clone()));
702 thread.add_tool(CreateDirectoryTool::new(project.clone()));
703 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
704 thread.add_tool(DiagnosticsTool::new(project.clone()));
705 thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
706 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
707 thread.add_tool(FindPathTool::new(project.clone()));
708 thread.add_tool(GrepTool::new(project.clone()));
709 thread.add_tool(ListDirectoryTool::new(project.clone()));
710 thread.add_tool(MovePathTool::new(project.clone()));
711 thread.add_tool(NowTool);
712 thread.add_tool(OpenTool::new(project.clone()));
713 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
714 thread.add_tool(TerminalTool::new(project.clone(), cx));
715 thread.add_tool(ThinkingTool);
716 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
717 }
718}
719
720impl AgentModelSelector for NativeAgentConnection {
721 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
722 log::debug!("NativeAgentConnection::list_models called");
723 let list = self.0.read(cx).models.model_list.clone();
724 Task::ready(if list.is_empty() {
725 Err(anyhow::anyhow!("No models available"))
726 } else {
727 Ok(list)
728 })
729 }
730
731 fn select_model(
732 &self,
733 session_id: acp::SessionId,
734 model_id: acp_thread::AgentModelId,
735 cx: &mut App,
736 ) -> Task<Result<()>> {
737 log::info!("Setting model for session {}: {}", session_id, model_id);
738 let Some(thread) = self
739 .0
740 .read(cx)
741 .sessions
742 .get(&session_id)
743 .map(|session| session.thread.clone())
744 else {
745 return Task::ready(Err(anyhow!("Session not found")));
746 };
747
748 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
749 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
750 };
751
752 thread.update(cx, |thread, cx| {
753 thread.set_model(model.clone(), cx);
754 });
755
756 update_settings_file::<AgentSettings>(
757 self.0.read(cx).fs.clone(),
758 cx,
759 move |settings, _cx| {
760 settings.set_model(model);
761 },
762 );
763
764 Task::ready(Ok(()))
765 }
766
767 fn selected_model(
768 &self,
769 session_id: &acp::SessionId,
770 cx: &mut App,
771 ) -> Task<Result<acp_thread::AgentModelInfo>> {
772 let session_id = session_id.clone();
773
774 let Some(thread) = self
775 .0
776 .read(cx)
777 .sessions
778 .get(&session_id)
779 .map(|session| session.thread.clone())
780 else {
781 return Task::ready(Err(anyhow!("Session not found")));
782 };
783 let Some(model) = thread.read(cx).model() else {
784 return Task::ready(Err(anyhow!("Model not found")));
785 };
786 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
787 else {
788 return Task::ready(Err(anyhow!("Provider not found")));
789 };
790 Task::ready(Ok(LanguageModels::map_language_model_to_info(
791 model, &provider,
792 )))
793 }
794
795 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
796 self.0.read(cx).models.watch()
797 }
798}
799
800impl acp_thread::AgentConnection for NativeAgentConnection {
801 fn new_thread(
802 self: Rc<Self>,
803 project: Entity<Project>,
804 cwd: &Path,
805 cx: &mut App,
806 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
807 let agent = self.0.clone();
808 log::info!("Creating new thread for project at: {:?}", cwd);
809
810 cx.spawn(async move |cx| {
811 log::debug!("Starting thread creation in async context");
812
813 // Generate session ID
814 let session_id = generate_session_id();
815 log::info!("Created session with ID: {}", session_id);
816
817 // Create AcpThread
818 let acp_thread = cx.update(|cx| {
819 cx.new(|cx| {
820 acp_thread::AcpThread::new(
821 "agent2",
822 self.clone(),
823 project.clone(),
824 session_id.clone(),
825 cx,
826 )
827 })
828 })?;
829 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
830
831 // Create Thread
832 let thread = agent.update(
833 cx,
834 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
835 // Fetch default model from registry settings
836 let registry = LanguageModelRegistry::read_global(cx);
837
838 // Log available models for debugging
839 let available_count = registry.available_models(cx).count();
840 log::debug!("Total available models: {}", available_count);
841
842 let default_model = registry.default_model().and_then(|default_model| {
843 agent
844 .models
845 .model_from_id(&LanguageModels::model_id(&default_model.model))
846 });
847
848 let summarization_model = registry.thread_summary_model().map(|c| c.model);
849
850 let thread = cx.new(|cx| {
851 let mut thread = Thread::new(
852 session_id.clone(),
853 project.clone(),
854 agent.project_context.clone(),
855 agent.context_server_registry.clone(),
856 action_log.clone(),
857 agent.templates.clone(),
858 default_model,
859 summarization_model,
860 cx,
861 );
862 Self::register_tools(&mut thread, project, action_log, cx);
863 thread
864 });
865
866 Ok(thread)
867 },
868 )??;
869
870 // Store the session
871 agent.update(cx, |agent, cx| {
872 agent.insert_session(thread, acp_thread.clone(), cx)
873 })?;
874
875 Ok(acp_thread)
876 })
877 }
878
879 fn auth_methods(&self) -> &[acp::AuthMethod] {
880 &[] // No auth for in-process
881 }
882
883 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
884 Task::ready(Ok(()))
885 }
886
887 fn list_threads(
888 &self,
889 cx: &mut App,
890 ) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
891 Some(self.0.read(cx).history.receiver())
892 }
893
894 fn load_thread(
895 self: Rc<Self>,
896 project: Entity<Project>,
897 _cwd: &Path,
898 session_id: acp::SessionId,
899 cx: &mut App,
900 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
901 let database = self.0.update(cx, |this, _| this.thread_database.clone());
902 cx.spawn(async move |cx| {
903 let db_thread = database
904 .load_thread(session_id.clone())
905 .await?
906 .context("no such thread found")?;
907
908 let acp_thread = cx.update(|cx| {
909 cx.new(|cx| {
910 acp_thread::AcpThread::new(
911 db_thread.title.clone(),
912 self.clone(),
913 project.clone(),
914 session_id.clone(),
915 cx,
916 )
917 })
918 })?;
919 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
920 let agent = self.0.clone();
921
922 // Create Thread
923 let thread = agent.update(cx, |agent, cx| {
924 let language_model_registry = LanguageModelRegistry::global(cx);
925 let configured_model = language_model_registry
926 .update(cx, |registry, cx| {
927 db_thread
928 .model
929 .as_ref()
930 .and_then(|model| {
931 let model = SelectedModel {
932 provider: model.provider.clone().into(),
933 model: model.model.clone().into(),
934 };
935 registry.select_model(&model, cx)
936 })
937 .or_else(|| registry.default_model())
938 })
939 .context("no default model configured")?;
940
941 let model = agent
942 .models
943 .model_from_id(&LanguageModels::model_id(&configured_model.model))
944 .context("no model by id")?;
945
946 let summarization_model = language_model_registry
947 .read(cx)
948 .thread_summary_model()
949 .map(|c| c.model);
950
951 let thread = cx.new(|cx| {
952 let mut thread = Thread::from_db(
953 session_id,
954 db_thread,
955 project.clone(),
956 agent.project_context.clone(),
957 agent.context_server_registry.clone(),
958 action_log.clone(),
959 agent.templates.clone(),
960 model,
961 summarization_model,
962 cx,
963 );
964 Self::register_tools(&mut thread, project, action_log, cx);
965 thread
966 });
967
968 anyhow::Ok(thread)
969 })??;
970
971 // Store the session
972 agent.update(cx, |agent, cx| {
973 agent.insert_session(thread.clone(), acp_thread.clone(), cx)
974 })?;
975
976 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
977 cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))?
978 .await?;
979
980 Ok(acp_thread)
981 })
982 }
983
984 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
985 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
986 }
987
988 fn prompt(
989 &self,
990 id: Option<acp_thread::UserMessageId>,
991 params: acp::PromptRequest,
992 cx: &mut App,
993 ) -> Task<Result<acp::PromptResponse>> {
994 let id = id.expect("UserMessageId is required");
995 let session_id = params.session_id.clone();
996 log::info!("Received prompt request for session: {}", session_id);
997 log::debug!("Prompt blocks count: {}", params.prompt.len());
998
999 self.run_turn(session_id, cx, |thread, cx| {
1000 let content: Vec<UserMessageContent> = params
1001 .prompt
1002 .into_iter()
1003 .map(Into::into)
1004 .collect::<Vec<_>>();
1005 log::info!("Converted prompt to message: {} chars", content.len());
1006 log::debug!("Message id: {:?}", id);
1007 log::debug!("Message content: {:?}", content);
1008
1009 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1010 })
1011 }
1012
1013 fn resume(
1014 &self,
1015 session_id: &acp::SessionId,
1016 _cx: &mut App,
1017 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1018 Some(Rc::new(NativeAgentSessionResume {
1019 connection: self.clone(),
1020 session_id: session_id.clone(),
1021 }) as _)
1022 }
1023
1024 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1025 log::info!("Cancelling on session: {}", session_id);
1026 self.0.update(cx, |agent, cx| {
1027 if let Some(agent) = agent.sessions.get(session_id) {
1028 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1029 }
1030 });
1031 }
1032
1033 fn session_editor(
1034 &self,
1035 session_id: &agent_client_protocol::SessionId,
1036 cx: &mut App,
1037 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
1038 self.0.update(cx, |agent, _cx| {
1039 agent
1040 .sessions
1041 .get(session_id)
1042 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
1043 })
1044 }
1045
1046 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1047 self
1048 }
1049}
1050
1051struct NativeAgentSessionEditor(Entity<Thread>);
1052
1053impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
1054 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1055 Task::ready(
1056 self.0
1057 .update(cx, |thread, cx| thread.truncate(message_id, cx)),
1058 )
1059 }
1060}
1061
1062struct NativeAgentSessionResume {
1063 connection: NativeAgentConnection,
1064 session_id: acp::SessionId,
1065}
1066
1067impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1068 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1069 self.connection
1070 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1071 thread.update(cx, |thread, cx| thread.resume(cx))
1072 })
1073 }
1074}
1075
#[cfg(test)]
mod tests {
    use crate::HistoryStore;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// The agent's project context should list each added worktree and pick
    /// up a `.rules` file once one is created inside it.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project starts with no worktrees, so the context has none either.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// `list_models` should report the fake provider's single model grouped
    /// under "Fake".
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model through the connection should update the session's
    /// thread and write the choice into the user's settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings with a different default model so the overwrite
        // below is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model. Read the thread through the
        // `&App` handed to this closure rather than re-entering the outer test
        // context from inside the agent read.
        agent.read_with(cx, |agent, cx| {
            let thread = agent.sessions.get(&session_id).unwrap().thread.read(cx);
            assert_eq!(thread.model().unwrap().id().0, "fake");
        });

        // Drain pending tasks (presumably including the settings-file write)
        // before inspecting the file.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// A completed turn should appear in the history store with the title
    /// produced by the summarization model.
    #[gpui::test]
    async fn test_history(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [], cx).await;

        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());
        // History store under test, with the native agent registered under
        // `NATIVE_AGENT_SERVER_NAME`.
        let history_store = cx.new(|cx| {
            let mut store = HistoryStore::new(cx);
            store.register_agent(NATIVE_AGENT_SERVER_NAME.clone(), &connection, cx);
            store
        });

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new(path!("")), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let selector = connection.model_selector().unwrap();

        // A separate fake model for summarization lets the test drive the title
        // independently of the main completion.
        let summarization_model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel::default());

        agent.update(cx, |agent, cx| {
            let thread = agent.sessions.get(&session_id).unwrap().thread.clone();
            thread.update(cx, |thread, cx| {
                thread.set_summarization_model(Some(summarization_model.clone()), cx);
            })
        });

        // Resolve the fake model selected for this session so its completion
        // stream can be driven by hand.
        let model = cx
            .update(|cx| selector.selected_model(&session_id, cx))
            .await
            .expect("selected_model should succeed");
        let model = cx
            .update(|cx| agent.read(cx).models().model_from_id(&model.id))
            .unwrap();
        let model = model.as_fake();

        // Start the turn in the background, let it issue its completion
        // request, then feed and close the response stream.
        let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Hi", cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();
        model.send_last_completion_stream_text_chunk("Hey");
        model.end_last_completion_stream();
        send.await.unwrap();

        // Answer the pending summarization request, then advance the clock by
        // the save debounce interval.
        summarization_model
            .as_fake()
            .send_last_completion_stream_text_chunk("Saying Hello");
        summarization_model.as_fake().end_last_completion_stream();
        cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE);

        let history = history_store.update(cx, |store, cx| store.entries(cx));
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].title(), "Saying Hello");
    }

    /// Shared setup for the tests in this module. It tries to initialize
    /// logging (ignoring failure if it is already set), installs a test
    /// `SettingsStore` global, and runs the project, agent-settings, language
    /// and language-model-registry test initializers.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}