1use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
2use crate::{
3 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
4 EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
5 OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
6 UserMessageContent, WebSearchTool, templates::Templates,
7};
8use crate::{DbThread, ThreadId, ThreadsDatabase, generate_session_id};
9use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
10use agent_client_protocol as acp;
11use agent_settings::AgentSettings;
12use anyhow::{Context as _, Result, anyhow};
13use collections::{HashSet, IndexMap};
14use fs::Fs;
15use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
16use futures::future::Shared;
17use futures::{SinkExt, StreamExt, future};
18use gpui::{
19 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
20};
21use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel};
22use project::{Project, ProjectItem, ProjectPath, Worktree};
23use prompt_store::{
24 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
25};
26use settings::update_settings_file;
27use std::any::Any;
28use std::cell::RefCell;
29use std::collections::HashMap;
30use std::path::Path;
31use std::rc::Rc;
32use std::sync::Arc;
33use util::ResultExt;
34
/// Candidate rules-file paths, relative to a worktree root, in priority order.
///
/// When building a worktree's context only the first entry that exists as a
/// file is loaded; updates to any of these paths trigger a context refresh.
// `&str` in a const is implicitly `'static` (clippy: redundant_static_lifetimes).
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
46
/// Delay before a changed thread is written to the database. Each change
/// schedules a fresh save task that replaces any pending one.
// Fully qualified because `Duration` is not imported at the top of this file.
const SAVE_THREAD_DEBOUNCE: std::time::Duration = std::time::Duration::from_millis(500);
48
/// An error encountered while loading a worktree's rules file.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the underlying error's `Display` text).
    pub message: SharedString,
}
52
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    /// Held weakly; the session is removed from the agent when this entity is released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Pending debounced save of `thread` to the database; replaced on every thread change.
    save_task: Task<Result<()>>,
    /// Keeps the session's observers alive (removal on ACP-thread release, save on change).
    _subscriptions: Vec<Subscription>,
}
62
/// Cache of the language models offered by authenticated providers, rebuilt
/// by `refresh_list`.
pub struct LanguageModels {
    /// Access language model by ID (IDs have the form `provider_id/model_id`)
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()`; notified whenever the list is refreshed.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
}
71
72impl LanguageModels {
73 fn new(cx: &App) -> Self {
74 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
75 let mut this = Self {
76 models: HashMap::default(),
77 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
78 refresh_models_rx,
79 refresh_models_tx,
80 };
81 this.refresh_list(cx);
82 this
83 }
84
85 fn refresh_list(&mut self, cx: &App) {
86 let providers = LanguageModelRegistry::global(cx)
87 .read(cx)
88 .providers()
89 .into_iter()
90 .filter(|provider| provider.is_authenticated(cx))
91 .collect::<Vec<_>>();
92
93 let mut language_model_list = IndexMap::default();
94 let mut recommended_models = HashSet::default();
95
96 let mut recommended = Vec::new();
97 for provider in &providers {
98 for model in provider.recommended_models(cx) {
99 recommended_models.insert(model.id());
100 recommended.push(Self::map_language_model_to_info(&model, &provider));
101 }
102 }
103 if !recommended.is_empty() {
104 language_model_list.insert(
105 acp_thread::AgentModelGroupName("Recommended".into()),
106 recommended,
107 );
108 }
109
110 let mut models = HashMap::default();
111 for provider in providers {
112 let mut provider_models = Vec::new();
113 for model in provider.provided_models(cx) {
114 let model_info = Self::map_language_model_to_info(&model, &provider);
115 let model_id = model_info.id.clone();
116 if !recommended_models.contains(&model.id()) {
117 provider_models.push(model_info);
118 }
119 models.insert(model_id, model);
120 }
121 if !provider_models.is_empty() {
122 language_model_list.insert(
123 acp_thread::AgentModelGroupName(provider.name().0.clone()),
124 provider_models,
125 );
126 }
127 }
128
129 self.models = models;
130 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
131 self.refresh_models_tx.send(()).ok();
132 }
133
134 fn watch(&self) -> watch::Receiver<()> {
135 self.refresh_models_rx.clone()
136 }
137
138 pub fn model_from_id(
139 &self,
140 model_id: &acp_thread::AgentModelId,
141 ) -> Option<Arc<dyn LanguageModel>> {
142 self.models.get(model_id).cloned()
143 }
144
145 fn map_language_model_to_info(
146 model: &Arc<dyn LanguageModel>,
147 provider: &Arc<dyn LanguageModelProvider>,
148 ) -> acp_thread::AgentModelInfo {
149 acp_thread::AgentModelInfo {
150 id: Self::model_id(model),
151 name: model.name().0,
152 icon: Some(provider.icon()),
153 }
154 }
155
156 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
157 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
158 }
159}
160
/// In-process agent that owns native `Thread` sessions for one project.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Rc<RefCell<ProjectContext>>,
    /// Signalled when the project context should be rebuilt (worktree, rules-file
    /// or prompt-store changes).
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that rebuilds `project_context` whenever it is signalled.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry handed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional source of default user rules for the project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// Persistent storage for threads.
    thread_database: Arc<ThreadsDatabase>,
    /// Latest list of saved threads; `None` until the first load completes.
    history: watch::Sender<Option<Vec<AcpThreadMetadata>>>,
    /// In-flight task that reloads `history` from the database.
    load_history: Task<Result<()>>,
    /// Filesystem used when writing model selections to the settings file.
    fs: Arc<dyn Fs>,
    /// Keeps the project, model-registry and prompt-store subscriptions alive.
    _subscriptions: Vec<Subscription>,
}
181
182impl NativeAgent {
183 pub async fn new(
184 project: Entity<Project>,
185 templates: Arc<Templates>,
186 prompt_store: Option<Entity<PromptStore>>,
187 fs: Arc<dyn Fs>,
188 cx: &mut AsyncApp,
189 ) -> Result<Entity<NativeAgent>> {
190 log::info!("Creating new NativeAgent");
191
192 let project_context = cx
193 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
194 .await;
195
196 let thread_database = cx
197 .update(|cx| ThreadsDatabase::connect(cx))?
198 .await
199 .map_err(|e| anyhow!(e))?;
200
201 cx.new(|cx| {
202 let mut subscriptions = vec![
203 cx.subscribe(&project, Self::handle_project_event),
204 cx.subscribe(
205 &LanguageModelRegistry::global(cx),
206 Self::handle_models_updated_event,
207 ),
208 ];
209 if let Some(prompt_store) = prompt_store.as_ref() {
210 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
211 }
212
213 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
214 watch::channel(());
215 let this = Self {
216 sessions: HashMap::new(),
217 project_context: Rc::new(RefCell::new(project_context)),
218 project_context_needs_refresh: project_context_needs_refresh_tx,
219 _maintain_project_context: cx.spawn(async move |this, cx| {
220 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
221 }),
222 context_server_registry: cx.new(|cx| {
223 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
224 }),
225 thread_database,
226 templates,
227 models: LanguageModels::new(cx),
228 project,
229 prompt_store,
230 fs,
231 history: watch::channel(None).0,
232 load_history: Task::ready(Ok(())),
233 _subscriptions: subscriptions,
234 };
235 this.reload_history(cx);
236 this
237 })
238 }
239
240 pub fn insert_session(
241 &mut self,
242 thread: Entity<Thread>,
243 acp_thread: Entity<AcpThread>,
244 cx: &mut Context<Self>,
245 ) {
246 let id = thread.read(cx).id().clone();
247 self.sessions.insert(
248 id,
249 Session {
250 thread: thread.clone(),
251 acp_thread: acp_thread.downgrade(),
252 save_task: Task::ready(()),
253 _subscriptions: vec![
254 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
255 this.sessions.remove(acp_thread.session_id());
256 }),
257 cx.observe(&thread, |this, thread, cx| {
258 this.save_thread(thread.clone(), cx)
259 }),
260 ],
261 },
262 );
263 }
264
265 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
266 let id = thread.read(cx).id().clone();
267 let Some(session) = self.sessions.get_mut(&id) else {
268 return;
269 };
270
271 let thread = thread.downgrade();
272 let thread_database = self.thread_database.clone();
273 session.save_task = cx.spawn(async move |this, cx| {
274 cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
275 let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await;
276 thread_database.save_thread(id, db_thread).await?;
277 this.update(cx, |this, cx| this.reload_history(cx))?;
278 Ok(())
279 });
280 }
281
282 fn reload_history(&mut self, cx: &mut Context<Self>) {
283 let thread_database = self.thread_database.clone();
284 self.load_history = cx.spawn(async move |this, cx| {
285 let results = cx
286 .background_spawn(async move {
287 let results = thread_database.list_threads().await?;
288 Ok(results
289 .into_iter()
290 .map(|thread| AcpThreadMetadata {
291 agent: NATIVE_AGENT_SERVER_NAME.clone(),
292 id: thread.id.into(),
293 title: thread.title,
294 updated_at: thread.updated_at,
295 })
296 .collect())
297 })
298 .await?;
299 this.update(cx, |this, cx| this.history.send(Some(results)))?;
300 anyhow::Ok(())
301 });
302 }
303
304 pub fn models(&self) -> &LanguageModels {
305 &self.models
306 }
307
308 async fn maintain_project_context(
309 this: WeakEntity<Self>,
310 mut needs_refresh: watch::Receiver<()>,
311 cx: &mut AsyncApp,
312 ) -> Result<()> {
313 while needs_refresh.changed().await.is_ok() {
314 let project_context = this
315 .update(cx, |this, cx| {
316 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
317 })?
318 .await;
319 this.update(cx, |this, _| this.project_context.replace(project_context))?;
320 }
321
322 Ok(())
323 }
324
325 fn build_project_context(
326 project: &Entity<Project>,
327 prompt_store: Option<&Entity<PromptStore>>,
328 cx: &mut App,
329 ) -> Task<ProjectContext> {
330 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
331 let worktree_tasks = worktrees
332 .into_iter()
333 .map(|worktree| {
334 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
335 })
336 .collect::<Vec<_>>();
337 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
338 prompt_store.read_with(cx, |prompt_store, cx| {
339 let prompts = prompt_store.default_prompt_metadata();
340 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
341 let contents = prompt_store.load(prompt_metadata.id, cx);
342 async move { (contents.await, prompt_metadata) }
343 });
344 cx.background_spawn(future::join_all(load_tasks))
345 })
346 } else {
347 Task::ready(vec![])
348 };
349
350 cx.spawn(async move |_cx| {
351 let (worktrees, default_user_rules) =
352 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
353
354 let worktrees = worktrees
355 .into_iter()
356 .map(|(worktree, _rules_error)| {
357 // TODO: show error message
358 // if let Some(rules_error) = rules_error {
359 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
360 // }
361 worktree
362 })
363 .collect::<Vec<_>>();
364
365 let default_user_rules = default_user_rules
366 .into_iter()
367 .flat_map(|(contents, prompt_metadata)| match contents {
368 Ok(contents) => Some(UserRulesContext {
369 uuid: match prompt_metadata.id {
370 PromptId::User { uuid } => uuid,
371 PromptId::EditWorkflow => return None,
372 },
373 title: prompt_metadata.title.map(|title| title.to_string()),
374 contents,
375 }),
376 Err(_err) => {
377 // TODO: show error message
378 // this.update(cx, |_, cx| {
379 // cx.emit(RulesLoadingError {
380 // message: format!("{err:?}").into(),
381 // });
382 // })
383 // .ok();
384 None
385 }
386 })
387 .collect::<Vec<_>>();
388
389 ProjectContext::new(worktrees, default_user_rules)
390 })
391 }
392
393 fn load_worktree_info_for_system_prompt(
394 worktree: Entity<Worktree>,
395 project: Entity<Project>,
396 cx: &mut App,
397 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
398 let tree = worktree.read(cx);
399 let root_name = tree.root_name().into();
400 let abs_path = tree.abs_path();
401
402 let mut context = WorktreeContext {
403 root_name,
404 abs_path,
405 rules_file: None,
406 };
407
408 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
409 let Some(rules_task) = rules_task else {
410 return Task::ready((context, None));
411 };
412
413 cx.spawn(async move |_| {
414 let (rules_file, rules_file_error) = match rules_task.await {
415 Ok(rules_file) => (Some(rules_file), None),
416 Err(err) => (
417 None,
418 Some(RulesLoadingError {
419 message: format!("{err}").into(),
420 }),
421 ),
422 };
423 context.rules_file = rules_file;
424 (context, rules_file_error)
425 })
426 }
427
428 fn load_worktree_rules_file(
429 worktree: Entity<Worktree>,
430 project: Entity<Project>,
431 cx: &mut App,
432 ) -> Option<Task<Result<RulesFileContext>>> {
433 let worktree = worktree.read(cx);
434 let worktree_id = worktree.id();
435 let selected_rules_file = RULES_FILE_NAMES
436 .into_iter()
437 .filter_map(|name| {
438 worktree
439 .entry_for_path(name)
440 .filter(|entry| entry.is_file())
441 .map(|entry| entry.path.clone())
442 })
443 .next();
444
445 // Note that Cline supports `.clinerules` being a directory, but that is not currently
446 // supported. This doesn't seem to occur often in GitHub repositories.
447 selected_rules_file.map(|path_in_worktree| {
448 let project_path = ProjectPath {
449 worktree_id,
450 path: path_in_worktree.clone(),
451 };
452 let buffer_task =
453 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
454 let rope_task = cx.spawn(async move |cx| {
455 buffer_task.await?.read_with(cx, |buffer, cx| {
456 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
457 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
458 })?
459 });
460 // Build a string from the rope on a background thread.
461 cx.background_spawn(async move {
462 let (project_entry_id, rope) = rope_task.await?;
463 anyhow::Ok(RulesFileContext {
464 path_in_worktree,
465 text: rope.to_string().trim().to_string(),
466 project_entry_id: project_entry_id.to_usize(),
467 })
468 })
469 })
470 }
471
472 fn handle_project_event(
473 &mut self,
474 _project: Entity<Project>,
475 event: &project::Event,
476 _cx: &mut Context<Self>,
477 ) {
478 match event {
479 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
480 self.project_context_needs_refresh.send(()).ok();
481 }
482 project::Event::WorktreeUpdatedEntries(_, items) => {
483 if items.iter().any(|(path, _, _)| {
484 RULES_FILE_NAMES
485 .iter()
486 .any(|name| path.as_ref() == Path::new(name))
487 }) {
488 self.project_context_needs_refresh.send(()).ok();
489 }
490 }
491 _ => {}
492 }
493 }
494
495 fn handle_prompts_updated_event(
496 &mut self,
497 _prompt_store: Entity<PromptStore>,
498 _event: &prompt_store::PromptsUpdatedEvent,
499 _cx: &mut Context<Self>,
500 ) {
501 self.project_context_needs_refresh.send(()).ok();
502 }
503
504 fn handle_models_updated_event(
505 &mut self,
506 _registry: Entity<LanguageModelRegistry>,
507 _event: &language_model::Event,
508 cx: &mut Context<Self>,
509 ) {
510 self.models.refresh_list(cx);
511 for session in self.sessions.values_mut() {
512 session.thread.update(cx, |thread, _| {
513 let model_id = LanguageModels::model_id(&thread.model());
514 if let Some(model) = self.models.model_from_id(&model_id) {
515 thread.set_model(model.clone());
516 }
517 });
518 }
519 }
520}
521
/// Wrapper struct that implements the AgentConnection trait (and
/// `AgentModelSelector`) on top of a shared `NativeAgent` entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
525
526impl NativeAgentConnection {
527 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
528 self.0
529 .read(cx)
530 .sessions
531 .get(session_id)
532 .map(|session| session.thread.clone())
533 }
534
535 fn run_turn(
536 &self,
537 session_id: acp::SessionId,
538 cx: &mut App,
539 f: impl 'static
540 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
541 ) -> Task<Result<acp::PromptResponse>> {
542 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
543 agent
544 .sessions
545 .get_mut(&session_id)
546 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
547 }) else {
548 return Task::ready(Err(anyhow!("Session not found")));
549 };
550 log::debug!("Found session for: {}", session_id);
551
552 let response_stream = match f(thread, cx) {
553 Ok(stream) => stream,
554 Err(err) => return Task::ready(Err(err)),
555 };
556 Self::handle_thread_events(response_stream, acp_thread, cx)
557 }
558
559 fn handle_thread_events(
560 mut response_stream: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
561 acp_thread: WeakEntity<AcpThread>,
562 cx: &mut App,
563 ) -> Task<Result<acp::PromptResponse>> {
564 cx.spawn(async move |cx| {
565 // Handle response stream and forward to session.acp_thread
566 while let Some(result) = response_stream.next().await {
567 match result {
568 Ok(event) => {
569 log::trace!("Received completion event: {:?}", event);
570
571 match event {
572 ThreadEvent::UserMessage(message) => {
573 acp_thread.update(cx, |thread, cx| {
574 for content in message.content {
575 thread.push_user_content_block(
576 Some(message.id.clone()),
577 content.into(),
578 cx,
579 );
580 }
581 })?;
582 }
583 ThreadEvent::AgentText(text) => {
584 acp_thread.update(cx, |thread, cx| {
585 thread.push_assistant_content_block(
586 acp::ContentBlock::Text(acp::TextContent {
587 text,
588 annotations: None,
589 }),
590 false,
591 cx,
592 )
593 })?;
594 }
595 ThreadEvent::AgentThinking(text) => {
596 acp_thread.update(cx, |thread, cx| {
597 thread.push_assistant_content_block(
598 acp::ContentBlock::Text(acp::TextContent {
599 text,
600 annotations: None,
601 }),
602 true,
603 cx,
604 )
605 })?;
606 }
607 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
608 tool_call,
609 options,
610 response,
611 }) => {
612 let recv = acp_thread.update(cx, |thread, cx| {
613 thread.request_tool_call_authorization(tool_call, options, cx)
614 })?;
615 cx.background_spawn(async move {
616 if let Some(recv) = recv.log_err()
617 && let Some(option) = recv
618 .await
619 .context("authorization sender was dropped")
620 .log_err()
621 {
622 response
623 .send(option)
624 .map(|_| anyhow!("authorization receiver was dropped"))
625 .log_err();
626 }
627 })
628 .detach();
629 }
630 ThreadEvent::ToolCall(tool_call) => {
631 acp_thread.update(cx, |thread, cx| {
632 thread.upsert_tool_call(tool_call, cx)
633 })??;
634 }
635 ThreadEvent::ToolCallUpdate(update) => {
636 acp_thread.update(cx, |thread, cx| {
637 thread.update_tool_call(update, cx)
638 })??;
639 }
640 ThreadEvent::Stop(stop_reason) => {
641 log::debug!("Assistant message complete: {:?}", stop_reason);
642 return Ok(acp::PromptResponse { stop_reason });
643 }
644 }
645 }
646 Err(e) => {
647 log::error!("Error in model response stream: {:?}", e);
648 return Err(e);
649 }
650 }
651 }
652
653 log::info!("Response stream completed");
654 anyhow::Ok(acp::PromptResponse {
655 stop_reason: acp::StopReason::EndTurn,
656 })
657 })
658 }
659
660 fn register_tools(
661 thread: &mut Thread,
662 project: Entity<Project>,
663 action_log: Entity<action_log::ActionLog>,
664 cx: &mut Context<Thread>,
665 ) {
666 let language_registry = project.read(cx).languages().clone();
667 thread.add_tool(CopyPathTool::new(project.clone()));
668 thread.add_tool(CreateDirectoryTool::new(project.clone()));
669 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
670 thread.add_tool(DiagnosticsTool::new(project.clone()));
671 thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
672 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
673 thread.add_tool(FindPathTool::new(project.clone()));
674 thread.add_tool(GrepTool::new(project.clone()));
675 thread.add_tool(ListDirectoryTool::new(project.clone()));
676 thread.add_tool(MovePathTool::new(project.clone()));
677 thread.add_tool(NowTool);
678 thread.add_tool(OpenTool::new(project.clone()));
679 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
680 thread.add_tool(TerminalTool::new(project.clone(), cx));
681 thread.add_tool(ThinkingTool);
682 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
683 }
684}
685
686impl AgentModelSelector for NativeAgentConnection {
687 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
688 log::debug!("NativeAgentConnection::list_models called");
689 let list = self.0.read(cx).models.model_list.clone();
690 Task::ready(if list.is_empty() {
691 Err(anyhow::anyhow!("No models available"))
692 } else {
693 Ok(list)
694 })
695 }
696
697 fn select_model(
698 &self,
699 session_id: acp::SessionId,
700 model_id: acp_thread::AgentModelId,
701 cx: &mut App,
702 ) -> Task<Result<()>> {
703 log::info!("Setting model for session {}: {}", session_id, model_id);
704 let Some(thread) = self
705 .0
706 .read(cx)
707 .sessions
708 .get(&session_id)
709 .map(|session| session.thread.clone())
710 else {
711 return Task::ready(Err(anyhow!("Session not found")));
712 };
713
714 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
715 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
716 };
717
718 thread.update(cx, |thread, _cx| {
719 thread.set_model(model.clone());
720 });
721
722 update_settings_file::<AgentSettings>(
723 self.0.read(cx).fs.clone(),
724 cx,
725 move |settings, _cx| {
726 settings.set_model(model);
727 },
728 );
729
730 Task::ready(Ok(()))
731 }
732
733 fn selected_model(
734 &self,
735 session_id: &acp::SessionId,
736 cx: &mut App,
737 ) -> Task<Result<acp_thread::AgentModelInfo>> {
738 let session_id = session_id.clone();
739
740 let Some(thread) = self
741 .0
742 .read(cx)
743 .sessions
744 .get(&session_id)
745 .map(|session| session.thread.clone())
746 else {
747 return Task::ready(Err(anyhow!("Session not found")));
748 };
749 let model = thread.read(cx).model().clone();
750 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
751 else {
752 return Task::ready(Err(anyhow!("Provider not found")));
753 };
754 Task::ready(Ok(LanguageModels::map_language_model_to_info(
755 &model, &provider,
756 )))
757 }
758
759 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
760 self.0.read(cx).models.watch()
761 }
762}
763
764impl acp_thread::AgentConnection for NativeAgentConnection {
765 fn new_thread(
766 self: Rc<Self>,
767 project: Entity<Project>,
768 cwd: &Path,
769 cx: &mut App,
770 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
771 let agent = self.0.clone();
772 log::info!("Creating new thread for project at: {:?}", cwd);
773
774 cx.spawn(async move |cx| {
775 log::debug!("Starting thread creation in async context");
776
777 // Generate session ID
778 let session_id = generate_session_id();
779 log::info!("Created session with ID: {}", session_id);
780
781 // Create AcpThread
782 let acp_thread = cx.update(|cx| {
783 cx.new(|cx| {
784 acp_thread::AcpThread::new(
785 "agent2",
786 self.clone(),
787 project.clone(),
788 session_id.clone(),
789 cx,
790 )
791 })
792 })?;
793 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
794
795 // Create Thread
796 let thread = agent.update(
797 cx,
798 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
799 // Fetch default model from registry settings
800 let registry = LanguageModelRegistry::read_global(cx);
801
802 // Log available models for debugging
803 let available_count = registry.available_models(cx).count();
804 log::debug!("Total available models: {}", available_count);
805
806 let default_model = registry
807 .default_model()
808 .and_then(|default_model| {
809 agent
810 .models
811 .model_from_id(&LanguageModels::model_id(&default_model.model))
812 })
813 .ok_or_else(|| {
814 log::warn!("No default model configured in settings");
815 anyhow!(
816 "No default model. Please configure a default model in settings."
817 )
818 })?;
819
820 let thread = cx.new(|cx| {
821 let mut thread = Thread::new(
822 session_id.clone(),
823 project.clone(),
824 agent.project_context.clone(),
825 agent.context_server_registry.clone(),
826 action_log.clone(),
827 agent.templates.clone(),
828 default_model,
829 cx,
830 );
831 Self::register_tools(&mut thread, project, action_log, cx);
832 thread
833 });
834
835 Ok(thread)
836 },
837 )??;
838
839 // Store the session
840 agent.update(cx, |agent, cx| {
841 agent.insert_session(thread, acp_thread.clone(), cx)
842 })?;
843
844 Ok(acp_thread)
845 })
846 }
847
848 fn auth_methods(&self) -> &[acp::AuthMethod] {
849 &[] // No auth for in-process
850 }
851
852 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
853 Task::ready(Ok(()))
854 }
855
856 fn list_threads(
857 &self,
858 cx: &mut App,
859 ) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
860 Some(self.0.read(cx).history.receiver())
861 }
862
863 fn load_thread(
864 self: Rc<Self>,
865 project: Entity<Project>,
866 _cwd: &Path,
867 session_id: acp::SessionId,
868 cx: &mut App,
869 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
870 let thread_id = ThreadId::from(session_id.clone());
871 let database = self.0.update(cx, |this, _| this.thread_database.clone());
872 cx.spawn(async move |cx| {
873 let database = database.await.map_err(|e| anyhow!(e))?;
874 let db_thread = database
875 .load_thread(thread_id.clone())
876 .await?
877 .context("no such thread found")?;
878
879 let acp_thread = cx.update(|cx| {
880 cx.new(|cx| {
881 acp_thread::AcpThread::new(
882 db_thread.title.clone(),
883 self.clone(),
884 project.clone(),
885 session_id.clone(),
886 cx,
887 )
888 })
889 })?;
890 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
891 let agent = self.0.clone();
892
893 // Create Thread
894 let thread = agent.update(cx, |agent, cx| {
895 let configured_model = LanguageModelRegistry::global(cx)
896 .update(cx, |registry, cx| {
897 db_thread
898 .model
899 .as_ref()
900 .and_then(|model| {
901 let model = SelectedModel {
902 provider: model.provider.clone().into(),
903 model: model.model.clone().into(),
904 };
905 registry.select_model(&model, cx)
906 })
907 .or_else(|| registry.default_model())
908 })
909 .context("no default model configured")?;
910
911 let model = agent
912 .models
913 .model_from_id(&LanguageModels::model_id(&configured_model.model))
914 .context("no model by id")?;
915
916 let thread = cx.new(|cx| {
917 let mut thread = Thread::from_db(
918 thread_id,
919 db_thread,
920 project.clone(),
921 agent.project_context.clone(),
922 agent.context_server_registry.clone(),
923 action_log.clone(),
924 agent.templates.clone(),
925 model,
926 cx,
927 );
928 Self::register_tools(&mut thread, project, action_log, cx);
929 thread
930 });
931
932 anyhow::Ok(thread)
933 })??;
934
935 // Store the session
936 agent.update(cx, |agent, cx| {
937 agent.insert_session(session_id, thread, acp_thread, cx)
938 })?;
939
940 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
941 cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))?
942 .await?;
943
944 Ok(acp_thread)
945 })
946 }
947
948 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
949 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
950 }
951
952 fn prompt(
953 &self,
954 id: Option<acp_thread::UserMessageId>,
955 params: acp::PromptRequest,
956 cx: &mut App,
957 ) -> Task<Result<acp::PromptResponse>> {
958 let id = id.expect("UserMessageId is required");
959 let session_id = params.session_id.clone();
960 log::info!("Received prompt request for session: {}", session_id);
961 log::debug!("Prompt blocks count: {}", params.prompt.len());
962
963 self.run_turn(session_id, cx, |thread, cx| {
964 let content: Vec<UserMessageContent> = params
965 .prompt
966 .into_iter()
967 .map(Into::into)
968 .collect::<Vec<_>>();
969 log::info!("Converted prompt to message: {} chars", content.len());
970 log::debug!("Message id: {:?}", id);
971 log::debug!("Message content: {:?}", content);
972
973 Ok(thread.update(cx, |thread, cx| {
974 log::info!(
975 "Sending message to thread with model: {:?}",
976 thread.model().name()
977 );
978 thread.send(id, content, cx)
979 }))
980 })
981 }
982
983 fn resume(
984 &self,
985 session_id: &acp::SessionId,
986 _cx: &mut App,
987 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
988 Some(Rc::new(NativeAgentSessionResume {
989 connection: self.clone(),
990 session_id: session_id.clone(),
991 }) as _)
992 }
993
994 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
995 log::info!("Cancelling on session: {}", session_id);
996 self.0.update(cx, |agent, cx| {
997 if let Some(agent) = agent.sessions.get(session_id) {
998 agent.thread.update(cx, |thread, _cx| thread.cancel());
999 }
1000 });
1001 }
1002
1003 fn session_editor(
1004 &self,
1005 session_id: &agent_client_protocol::SessionId,
1006 cx: &mut App,
1007 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
1008 self.0.update(cx, |agent, _cx| {
1009 agent
1010 .sessions
1011 .get(session_id)
1012 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
1013 })
1014 }
1015
1016 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1017 self
1018 }
1019}
1020
1021struct NativeAgentSessionEditor(Entity<Thread>);
1022
1023impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
1024 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1025 Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
1026 }
1027}
1028
1029struct NativeAgentSessionResume {
1030 connection: NativeAgentConnection,
1031 session_id: acp::SessionId,
1032}
1033
1034impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1035 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1036 self.connection
1037 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1038 thread.update(cx, |thread, cx| thread.resume(cx))
1039 })
1040 }
1041}
1042
#[cfg(test)]
mod tests {
    // Tests for `NativeAgent` / `NativeAgentConnection`, run against a fake
    // filesystem and the test `LanguageModelRegistry` (see `init_test`).
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    // The shared project context tracks worktrees as they are added, and picks up
    // a rules file created after the worktree was added.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // No worktrees yet, so the context is empty.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        // Let the WorktreeAdded-triggered refresh finish.
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // With only the test registry's fake provider, `list_models` returns a single
    // "Fake" group holding the `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    // Selecting a model switches the session's thread to it and overwrites the
    // `agent.default_model` entry in the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings with a different default model so the overwrite is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().id().0, "fake");
            });
        });

        // Let the settings-file write complete.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // Installs test settings, project/agent/language settings and the test
    // language-model registry used by every test above.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}