1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use futures::future::Shared;
11use language::language_settings::FormatOnSave;
12pub use mention::*;
13use project::lsp_store::{FormatTrigger, LspFormatTarget};
14use serde::{Deserialize, Serialize};
15use settings::Settings as _;
16pub use terminal::*;
17
18use action_log::ActionLog;
19use agent_client_protocol::{self as acp};
20use anyhow::{Context as _, Result, anyhow};
21use editor::Bias;
22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
24use itertools::Itertools;
25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
26use markdown::Markdown;
27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
28use std::collections::HashMap;
29use std::error::Error;
30use std::fmt::{Formatter, Write};
31use std::ops::Range;
32use std::process::ExitStatus;
33use std::rc::Rc;
34use std::time::{Duration, Instant};
35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
36use ui::App;
37use util::{ResultExt, get_system_shell};
38use uuid::Uuid;
39
40#[derive(Debug)]
41pub struct UserMessage {
42 pub id: Option<UserMessageId>,
43 pub content: ContentBlock,
44 pub chunks: Vec<acp::ContentBlock>,
45 pub checkpoint: Option<Checkpoint>,
46}
47
48#[derive(Debug)]
49pub struct Checkpoint {
50 git_checkpoint: GitStoreCheckpoint,
51 pub show: bool,
52}
53
54impl UserMessage {
55 fn to_markdown(&self, cx: &App) -> String {
56 let mut markdown = String::new();
57 if self
58 .checkpoint
59 .as_ref()
60 .is_some_and(|checkpoint| checkpoint.show)
61 {
62 writeln!(markdown, "## User (checkpoint)").unwrap();
63 } else {
64 writeln!(markdown, "## User").unwrap();
65 }
66 writeln!(markdown).unwrap();
67 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
68 writeln!(markdown).unwrap();
69 markdown
70 }
71}
72
73#[derive(Debug, PartialEq)]
74pub struct AssistantMessage {
75 pub chunks: Vec<AssistantMessageChunk>,
76}
77
78impl AssistantMessage {
79 pub fn to_markdown(&self, cx: &App) -> String {
80 format!(
81 "## Assistant\n\n{}\n\n",
82 self.chunks
83 .iter()
84 .map(|chunk| chunk.to_markdown(cx))
85 .join("\n\n")
86 )
87 }
88}
89
90#[derive(Debug, PartialEq)]
91pub enum AssistantMessageChunk {
92 Message { block: ContentBlock },
93 Thought { block: ContentBlock },
94}
95
96impl AssistantMessageChunk {
97 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
98 Self::Message {
99 block: ContentBlock::new(chunk.into(), language_registry, cx),
100 }
101 }
102
103 fn to_markdown(&self, cx: &App) -> String {
104 match self {
105 Self::Message { block } => block.to_markdown(cx).to_string(),
106 Self::Thought { block } => {
107 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
108 }
109 }
110 }
111}
112
113#[derive(Debug)]
114pub enum AgentThreadEntry {
115 UserMessage(UserMessage),
116 AssistantMessage(AssistantMessage),
117 ToolCall(ToolCall),
118}
119
120impl AgentThreadEntry {
121 pub fn to_markdown(&self, cx: &App) -> String {
122 match self {
123 Self::UserMessage(message) => message.to_markdown(cx),
124 Self::AssistantMessage(message) => message.to_markdown(cx),
125 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
126 }
127 }
128
129 pub fn user_message(&self) -> Option<&UserMessage> {
130 if let AgentThreadEntry::UserMessage(message) = self {
131 Some(message)
132 } else {
133 None
134 }
135 }
136
137 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
138 if let AgentThreadEntry::ToolCall(call) = self {
139 itertools::Either::Left(call.diffs())
140 } else {
141 itertools::Either::Right(std::iter::empty())
142 }
143 }
144
145 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
146 if let AgentThreadEntry::ToolCall(call) = self {
147 itertools::Either::Left(call.terminals())
148 } else {
149 itertools::Either::Right(std::iter::empty())
150 }
151 }
152
153 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
154 if let AgentThreadEntry::ToolCall(ToolCall {
155 locations,
156 resolved_locations,
157 ..
158 }) = self
159 {
160 Some((
161 locations.get(ix)?.clone(),
162 resolved_locations.get(ix)?.clone()?,
163 ))
164 } else {
165 None
166 }
167 }
168}
169
170#[derive(Debug)]
171pub struct ToolCall {
172 pub id: acp::ToolCallId,
173 pub label: Entity<Markdown>,
174 pub kind: acp::ToolKind,
175 pub content: Vec<ToolCallContent>,
176 pub status: ToolCallStatus,
177 pub locations: Vec<acp::ToolCallLocation>,
178 pub resolved_locations: Vec<Option<AgentLocation>>,
179 pub raw_input: Option<serde_json::Value>,
180 pub raw_output: Option<serde_json::Value>,
181}
182
183impl ToolCall {
184 fn from_acp(
185 tool_call: acp::ToolCall,
186 status: ToolCallStatus,
187 language_registry: Arc<LanguageRegistry>,
188 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
189 cx: &mut App,
190 ) -> Result<Self> {
191 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
192 first_line.to_owned() + "…"
193 } else {
194 tool_call.title
195 };
196 let mut content = Vec::with_capacity(tool_call.content.len());
197 for item in tool_call.content {
198 content.push(ToolCallContent::from_acp(
199 item,
200 language_registry.clone(),
201 terminals,
202 cx,
203 )?);
204 }
205
206 let result = Self {
207 id: tool_call.id,
208 label: cx
209 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
210 kind: tool_call.kind,
211 content,
212 locations: tool_call.locations,
213 resolved_locations: Vec::default(),
214 status,
215 raw_input: tool_call.raw_input,
216 raw_output: tool_call.raw_output,
217 };
218 Ok(result)
219 }
220
221 fn update_fields(
222 &mut self,
223 fields: acp::ToolCallUpdateFields,
224 language_registry: Arc<LanguageRegistry>,
225 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
226 cx: &mut App,
227 ) -> Result<()> {
228 let acp::ToolCallUpdateFields {
229 kind,
230 status,
231 title,
232 content,
233 locations,
234 raw_input,
235 raw_output,
236 } = fields;
237
238 if let Some(kind) = kind {
239 self.kind = kind;
240 }
241
242 if let Some(status) = status {
243 self.status = status.into();
244 }
245
246 if let Some(title) = title {
247 self.label.update(cx, |label, cx| {
248 if let Some((first_line, _)) = title.split_once("\n") {
249 label.replace(first_line.to_owned() + "…", cx)
250 } else {
251 label.replace(title, cx);
252 }
253 });
254 }
255
256 if let Some(content) = content {
257 let new_content_len = content.len();
258 let mut content = content.into_iter();
259
260 // Reuse existing content if we can
261 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
262 old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
263 }
264 for new in content {
265 self.content.push(ToolCallContent::from_acp(
266 new,
267 language_registry.clone(),
268 terminals,
269 cx,
270 )?)
271 }
272 self.content.truncate(new_content_len);
273 }
274
275 if let Some(locations) = locations {
276 self.locations = locations;
277 }
278
279 if let Some(raw_input) = raw_input {
280 self.raw_input = Some(raw_input);
281 }
282
283 if let Some(raw_output) = raw_output {
284 if self.content.is_empty()
285 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
286 {
287 self.content
288 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
289 markdown,
290 }));
291 }
292 self.raw_output = Some(raw_output);
293 }
294 Ok(())
295 }
296
297 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
298 self.content.iter().filter_map(|content| match content {
299 ToolCallContent::Diff(diff) => Some(diff),
300 ToolCallContent::ContentBlock(_) => None,
301 ToolCallContent::Terminal(_) => None,
302 })
303 }
304
305 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
306 self.content.iter().filter_map(|content| match content {
307 ToolCallContent::Terminal(terminal) => Some(terminal),
308 ToolCallContent::ContentBlock(_) => None,
309 ToolCallContent::Diff(_) => None,
310 })
311 }
312
313 fn to_markdown(&self, cx: &App) -> String {
314 let mut markdown = format!(
315 "**Tool Call: {}**\nStatus: {}\n\n",
316 self.label.read(cx).source(),
317 self.status
318 );
319 for content in &self.content {
320 markdown.push_str(content.to_markdown(cx).as_str());
321 markdown.push_str("\n\n");
322 }
323 markdown
324 }
325
326 async fn resolve_location(
327 location: acp::ToolCallLocation,
328 project: WeakEntity<Project>,
329 cx: &mut AsyncApp,
330 ) -> Option<AgentLocation> {
331 let buffer = project
332 .update(cx, |project, cx| {
333 project
334 .project_path_for_absolute_path(&location.path, cx)
335 .map(|path| project.open_buffer(path, cx))
336 })
337 .ok()??;
338 let buffer = buffer.await.log_err()?;
339 let position = buffer
340 .update(cx, |buffer, _| {
341 if let Some(row) = location.line {
342 let snapshot = buffer.snapshot();
343 let column = snapshot.indent_size_for_line(row).len;
344 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
345 snapshot.anchor_before(point)
346 } else {
347 Anchor::MIN
348 }
349 })
350 .ok()?;
351
352 Some(AgentLocation {
353 buffer: buffer.downgrade(),
354 position,
355 })
356 }
357
358 fn resolve_locations(
359 &self,
360 project: Entity<Project>,
361 cx: &mut App,
362 ) -> Task<Vec<Option<AgentLocation>>> {
363 let locations = self.locations.clone();
364 project.update(cx, |_, cx| {
365 cx.spawn(async move |project, cx| {
366 let mut new_locations = Vec::new();
367 for location in locations {
368 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
369 }
370 new_locations
371 })
372 })
373 }
374}
375
376#[derive(Debug)]
377pub enum ToolCallStatus {
378 /// The tool call hasn't started running yet, but we start showing it to
379 /// the user.
380 Pending,
381 /// The tool call is waiting for confirmation from the user.
382 WaitingForConfirmation {
383 options: Vec<acp::PermissionOption>,
384 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
385 },
386 /// The tool call is currently running.
387 InProgress,
388 /// The tool call completed successfully.
389 Completed,
390 /// The tool call failed.
391 Failed,
392 /// The user rejected the tool call.
393 Rejected,
394 /// The user canceled generation so the tool call was canceled.
395 Canceled,
396}
397
398impl From<acp::ToolCallStatus> for ToolCallStatus {
399 fn from(status: acp::ToolCallStatus) -> Self {
400 match status {
401 acp::ToolCallStatus::Pending => Self::Pending,
402 acp::ToolCallStatus::InProgress => Self::InProgress,
403 acp::ToolCallStatus::Completed => Self::Completed,
404 acp::ToolCallStatus::Failed => Self::Failed,
405 }
406 }
407}
408
409impl Display for ToolCallStatus {
410 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
411 write!(
412 f,
413 "{}",
414 match self {
415 ToolCallStatus::Pending => "Pending",
416 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
417 ToolCallStatus::InProgress => "In Progress",
418 ToolCallStatus::Completed => "Completed",
419 ToolCallStatus::Failed => "Failed",
420 ToolCallStatus::Rejected => "Rejected",
421 ToolCallStatus::Canceled => "Canceled",
422 }
423 )
424 }
425}
426
427#[derive(Debug, PartialEq, Clone)]
428pub enum ContentBlock {
429 Empty,
430 Markdown { markdown: Entity<Markdown> },
431 ResourceLink { resource_link: acp::ResourceLink },
432}
433
434impl ContentBlock {
435 pub fn new(
436 block: acp::ContentBlock,
437 language_registry: &Arc<LanguageRegistry>,
438 cx: &mut App,
439 ) -> Self {
440 let mut this = Self::Empty;
441 this.append(block, language_registry, cx);
442 this
443 }
444
445 pub fn new_combined(
446 blocks: impl IntoIterator<Item = acp::ContentBlock>,
447 language_registry: Arc<LanguageRegistry>,
448 cx: &mut App,
449 ) -> Self {
450 let mut this = Self::Empty;
451 for block in blocks {
452 this.append(block, &language_registry, cx);
453 }
454 this
455 }
456
457 pub fn append(
458 &mut self,
459 block: acp::ContentBlock,
460 language_registry: &Arc<LanguageRegistry>,
461 cx: &mut App,
462 ) {
463 if matches!(self, ContentBlock::Empty)
464 && let acp::ContentBlock::ResourceLink(resource_link) = block
465 {
466 *self = ContentBlock::ResourceLink { resource_link };
467 return;
468 }
469
470 let new_content = self.block_string_contents(block);
471
472 match self {
473 ContentBlock::Empty => {
474 *self = Self::create_markdown_block(new_content, language_registry, cx);
475 }
476 ContentBlock::Markdown { markdown } => {
477 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
478 }
479 ContentBlock::ResourceLink { resource_link } => {
480 let existing_content = Self::resource_link_md(&resource_link.uri);
481 let combined = format!("{}\n{}", existing_content, new_content);
482
483 *self = Self::create_markdown_block(combined, language_registry, cx);
484 }
485 }
486 }
487
488 fn create_markdown_block(
489 content: String,
490 language_registry: &Arc<LanguageRegistry>,
491 cx: &mut App,
492 ) -> ContentBlock {
493 ContentBlock::Markdown {
494 markdown: cx
495 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
496 }
497 }
498
499 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
500 match block {
501 acp::ContentBlock::Text(text_content) => text_content.text,
502 acp::ContentBlock::ResourceLink(resource_link) => {
503 Self::resource_link_md(&resource_link.uri)
504 }
505 acp::ContentBlock::Resource(acp::EmbeddedResource {
506 resource:
507 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
508 uri,
509 ..
510 }),
511 ..
512 }) => Self::resource_link_md(&uri),
513 acp::ContentBlock::Image(image) => Self::image_md(&image),
514 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
515 }
516 }
517
518 fn resource_link_md(uri: &str) -> String {
519 if let Some(uri) = MentionUri::parse(uri).log_err() {
520 uri.as_link().to_string()
521 } else {
522 uri.to_string()
523 }
524 }
525
526 fn image_md(_image: &acp::ImageContent) -> String {
527 "`Image`".into()
528 }
529
530 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
531 match self {
532 ContentBlock::Empty => "",
533 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
534 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
535 }
536 }
537
538 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
539 match self {
540 ContentBlock::Empty => None,
541 ContentBlock::Markdown { markdown } => Some(markdown),
542 ContentBlock::ResourceLink { .. } => None,
543 }
544 }
545
546 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
547 match self {
548 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
549 _ => None,
550 }
551 }
552}
553
554#[derive(Debug)]
555pub enum ToolCallContent {
556 ContentBlock(ContentBlock),
557 Diff(Entity<Diff>),
558 Terminal(Entity<Terminal>),
559}
560
561impl ToolCallContent {
562 pub fn from_acp(
563 content: acp::ToolCallContent,
564 language_registry: Arc<LanguageRegistry>,
565 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
566 cx: &mut App,
567 ) -> Result<Self> {
568 match content {
569 acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
570 content,
571 &language_registry,
572 cx,
573 ))),
574 acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
575 Diff::finalized(
576 diff.path,
577 diff.old_text,
578 diff.new_text,
579 language_registry,
580 cx,
581 )
582 }))),
583 acp::ToolCallContent::Terminal { terminal_id } => terminals
584 .get(&terminal_id)
585 .cloned()
586 .map(Self::Terminal)
587 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
588 }
589 }
590
591 pub fn update_from_acp(
592 &mut self,
593 new: acp::ToolCallContent,
594 language_registry: Arc<LanguageRegistry>,
595 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
596 cx: &mut App,
597 ) -> Result<()> {
598 let needs_update = match (&self, &new) {
599 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
600 old_diff.read(cx).needs_update(
601 new_diff.old_text.as_deref().unwrap_or(""),
602 &new_diff.new_text,
603 cx,
604 )
605 }
606 _ => true,
607 };
608
609 if needs_update {
610 *self = Self::from_acp(new, language_registry, terminals, cx)?;
611 }
612 Ok(())
613 }
614
615 pub fn to_markdown(&self, cx: &App) -> String {
616 match self {
617 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
618 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
619 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
620 }
621 }
622}
623
624#[derive(Debug, PartialEq)]
625pub enum ToolCallUpdate {
626 UpdateFields(acp::ToolCallUpdate),
627 UpdateDiff(ToolCallUpdateDiff),
628 UpdateTerminal(ToolCallUpdateTerminal),
629}
630
631impl ToolCallUpdate {
632 fn id(&self) -> &acp::ToolCallId {
633 match self {
634 Self::UpdateFields(update) => &update.id,
635 Self::UpdateDiff(diff) => &diff.id,
636 Self::UpdateTerminal(terminal) => &terminal.id,
637 }
638 }
639}
640
641impl From<acp::ToolCallUpdate> for ToolCallUpdate {
642 fn from(update: acp::ToolCallUpdate) -> Self {
643 Self::UpdateFields(update)
644 }
645}
646
647impl From<ToolCallUpdateDiff> for ToolCallUpdate {
648 fn from(diff: ToolCallUpdateDiff) -> Self {
649 Self::UpdateDiff(diff)
650 }
651}
652
653#[derive(Debug, PartialEq)]
654pub struct ToolCallUpdateDiff {
655 pub id: acp::ToolCallId,
656 pub diff: Entity<Diff>,
657}
658
659impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
660 fn from(terminal: ToolCallUpdateTerminal) -> Self {
661 Self::UpdateTerminal(terminal)
662 }
663}
664
665#[derive(Debug, PartialEq)]
666pub struct ToolCallUpdateTerminal {
667 pub id: acp::ToolCallId,
668 pub terminal: Entity<Terminal>,
669}
670
671#[derive(Debug, Default)]
672pub struct Plan {
673 pub entries: Vec<PlanEntry>,
674}
675
676#[derive(Debug)]
677pub struct PlanStats<'a> {
678 pub in_progress_entry: Option<&'a PlanEntry>,
679 pub pending: u32,
680 pub completed: u32,
681}
682
683impl Plan {
684 pub fn is_empty(&self) -> bool {
685 self.entries.is_empty()
686 }
687
688 pub fn stats(&self) -> PlanStats<'_> {
689 let mut stats = PlanStats {
690 in_progress_entry: None,
691 pending: 0,
692 completed: 0,
693 };
694
695 for entry in &self.entries {
696 match &entry.status {
697 acp::PlanEntryStatus::Pending => {
698 stats.pending += 1;
699 }
700 acp::PlanEntryStatus::InProgress => {
701 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
702 }
703 acp::PlanEntryStatus::Completed => {
704 stats.completed += 1;
705 }
706 }
707 }
708
709 stats
710 }
711}
712
713#[derive(Debug)]
714pub struct PlanEntry {
715 pub content: Entity<Markdown>,
716 pub priority: acp::PlanEntryPriority,
717 pub status: acp::PlanEntryStatus,
718}
719
720impl PlanEntry {
721 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
722 Self {
723 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
724 priority: entry.priority,
725 status: entry.status,
726 }
727 }
728}
729
730#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
731pub struct TokenUsage {
732 pub max_tokens: u64,
733 pub used_tokens: u64,
734}
735
736impl TokenUsage {
737 pub fn ratio(&self) -> TokenUsageRatio {
738 #[cfg(debug_assertions)]
739 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
740 .unwrap_or("0.8".to_string())
741 .parse()
742 .unwrap();
743 #[cfg(not(debug_assertions))]
744 let warning_threshold: f32 = 0.8;
745
746 // When the maximum is unknown because there is no selected model,
747 // avoid showing the token limit warning.
748 if self.max_tokens == 0 {
749 TokenUsageRatio::Normal
750 } else if self.used_tokens >= self.max_tokens {
751 TokenUsageRatio::Exceeded
752 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
753 TokenUsageRatio::Warning
754 } else {
755 TokenUsageRatio::Normal
756 }
757 }
758}
759
760#[derive(Debug, Clone, PartialEq, Eq)]
761pub enum TokenUsageRatio {
762 Normal,
763 Warning,
764 Exceeded,
765}
766
767#[derive(Debug, Clone)]
768pub struct RetryStatus {
769 pub last_error: SharedString,
770 pub attempt: usize,
771 pub max_attempts: usize,
772 pub started_at: Instant,
773 pub duration: Duration,
774}
775
776pub struct AcpThread {
777 title: SharedString,
778 entries: Vec<AgentThreadEntry>,
779 plan: Plan,
780 project: Entity<Project>,
781 action_log: Entity<ActionLog>,
782 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
783 send_task: Option<Task<()>>,
784 connection: Rc<dyn AgentConnection>,
785 session_id: acp::SessionId,
786 token_usage: Option<TokenUsage>,
787 prompt_capabilities: acp::PromptCapabilities,
788 available_commands: Vec<acp::AvailableCommand>,
789 _observe_prompt_capabilities: Task<anyhow::Result<()>>,
790 determine_shell: Shared<Task<String>>,
791 terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
792}
793
794#[derive(Debug)]
795pub enum AcpThreadEvent {
796 NewEntry,
797 TitleUpdated,
798 TokenUsageUpdated,
799 EntryUpdated(usize),
800 EntriesRemoved(Range<usize>),
801 ToolAuthorizationRequired,
802 Retry(RetryStatus),
803 Stopped,
804 Error,
805 LoadError(LoadError),
806 PromptCapabilitiesUpdated,
807}
808
809impl EventEmitter<AcpThreadEvent> for AcpThread {}
810
/// Coarse activity state of a thread, as returned by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running and a tool call awaits user confirmation.
    WaitingForToolConfirmation,
    /// A send task is running.
    Generating,
}
817
818#[derive(Debug, Clone)]
819pub enum LoadError {
820 Unsupported {
821 command: SharedString,
822 current_version: SharedString,
823 minimum_version: SharedString,
824 },
825 FailedToInstall(SharedString),
826 Exited {
827 status: ExitStatus,
828 },
829 Other(SharedString),
830}
831
832impl Display for LoadError {
833 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
834 match self {
835 LoadError::Unsupported {
836 command: path,
837 current_version,
838 minimum_version,
839 } => {
840 write!(
841 f,
842 "version {current_version} from {path} is not supported (need at least {minimum_version})"
843 )
844 }
845 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
846 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
847 LoadError::Other(msg) => write!(f, "{msg}"),
848 }
849 }
850}
851
852impl Error for LoadError {}
853
854impl AcpThread {
855 pub fn new(
856 title: impl Into<SharedString>,
857 connection: Rc<dyn AgentConnection>,
858 project: Entity<Project>,
859 action_log: Entity<ActionLog>,
860 session_id: acp::SessionId,
861 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
862 available_commands: Vec<acp::AvailableCommand>,
863 cx: &mut Context<Self>,
864 ) -> Self {
865 let prompt_capabilities = *prompt_capabilities_rx.borrow();
866 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
867 loop {
868 let caps = prompt_capabilities_rx.recv().await?;
869 this.update(cx, |this, cx| {
870 this.prompt_capabilities = caps;
871 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
872 })?;
873 }
874 });
875
876 let determine_shell = cx
877 .background_spawn(async move {
878 if cfg!(windows) {
879 return get_system_shell();
880 }
881
882 if which::which("bash").is_ok() {
883 "bash".into()
884 } else {
885 get_system_shell()
886 }
887 })
888 .shared();
889
890 Self {
891 action_log,
892 shared_buffers: Default::default(),
893 entries: Default::default(),
894 plan: Default::default(),
895 title: title.into(),
896 project,
897 send_task: None,
898 connection,
899 session_id,
900 token_usage: None,
901 prompt_capabilities,
902 available_commands,
903 _observe_prompt_capabilities: task,
904 terminals: HashMap::default(),
905 determine_shell,
906 }
907 }
908
909 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
910 self.prompt_capabilities
911 }
912
913 pub fn available_commands(&self) -> Vec<acp::AvailableCommand> {
914 self.available_commands.clone()
915 }
916
917 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
918 &self.connection
919 }
920
921 pub fn action_log(&self) -> &Entity<ActionLog> {
922 &self.action_log
923 }
924
925 pub fn project(&self) -> &Entity<Project> {
926 &self.project
927 }
928
929 pub fn title(&self) -> SharedString {
930 self.title.clone()
931 }
932
933 pub fn entries(&self) -> &[AgentThreadEntry] {
934 &self.entries
935 }
936
937 pub fn session_id(&self) -> &acp::SessionId {
938 &self.session_id
939 }
940
941 pub fn status(&self) -> ThreadStatus {
942 if self.send_task.is_some() {
943 if self.waiting_for_tool_confirmation() {
944 ThreadStatus::WaitingForToolConfirmation
945 } else {
946 ThreadStatus::Generating
947 }
948 } else {
949 ThreadStatus::Idle
950 }
951 }
952
953 pub fn token_usage(&self) -> Option<&TokenUsage> {
954 self.token_usage.as_ref()
955 }
956
957 pub fn has_pending_edit_tool_calls(&self) -> bool {
958 for entry in self.entries.iter().rev() {
959 match entry {
960 AgentThreadEntry::UserMessage(_) => return false,
961 AgentThreadEntry::ToolCall(
962 call @ ToolCall {
963 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
964 ..
965 },
966 ) if call.diffs().next().is_some() => {
967 return true;
968 }
969 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
970 }
971 }
972
973 false
974 }
975
976 pub fn used_tools_since_last_user_message(&self) -> bool {
977 for entry in self.entries.iter().rev() {
978 match entry {
979 AgentThreadEntry::UserMessage(..) => return false,
980 AgentThreadEntry::AssistantMessage(..) => continue,
981 AgentThreadEntry::ToolCall(..) => return true,
982 }
983 }
984
985 false
986 }
987
988 pub fn handle_session_update(
989 &mut self,
990 update: acp::SessionUpdate,
991 cx: &mut Context<Self>,
992 ) -> Result<(), acp::Error> {
993 match update {
994 acp::SessionUpdate::UserMessageChunk { content } => {
995 self.push_user_content_block(None, content, cx);
996 }
997 acp::SessionUpdate::AgentMessageChunk { content } => {
998 self.push_assistant_content_block(content, false, cx);
999 }
1000 acp::SessionUpdate::AgentThoughtChunk { content } => {
1001 self.push_assistant_content_block(content, true, cx);
1002 }
1003 acp::SessionUpdate::ToolCall(tool_call) => {
1004 self.upsert_tool_call(tool_call, cx)?;
1005 }
1006 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1007 self.update_tool_call(tool_call_update, cx)?;
1008 }
1009 acp::SessionUpdate::Plan(plan) => {
1010 self.update_plan(plan, cx);
1011 }
1012 }
1013 Ok(())
1014 }
1015
1016 pub fn push_user_content_block(
1017 &mut self,
1018 message_id: Option<UserMessageId>,
1019 chunk: acp::ContentBlock,
1020 cx: &mut Context<Self>,
1021 ) {
1022 let language_registry = self.project.read(cx).languages().clone();
1023 let entries_len = self.entries.len();
1024
1025 if let Some(last_entry) = self.entries.last_mut()
1026 && let AgentThreadEntry::UserMessage(UserMessage {
1027 id,
1028 content,
1029 chunks,
1030 ..
1031 }) = last_entry
1032 {
1033 *id = message_id.or(id.take());
1034 content.append(chunk.clone(), &language_registry, cx);
1035 chunks.push(chunk);
1036 let idx = entries_len - 1;
1037 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1038 } else {
1039 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1040 self.push_entry(
1041 AgentThreadEntry::UserMessage(UserMessage {
1042 id: message_id,
1043 content,
1044 chunks: vec![chunk],
1045 checkpoint: None,
1046 }),
1047 cx,
1048 );
1049 }
1050 }
1051
1052 pub fn push_assistant_content_block(
1053 &mut self,
1054 chunk: acp::ContentBlock,
1055 is_thought: bool,
1056 cx: &mut Context<Self>,
1057 ) {
1058 let language_registry = self.project.read(cx).languages().clone();
1059 let entries_len = self.entries.len();
1060 if let Some(last_entry) = self.entries.last_mut()
1061 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1062 {
1063 let idx = entries_len - 1;
1064 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1065 match (chunks.last_mut(), is_thought) {
1066 (Some(AssistantMessageChunk::Message { block }), false)
1067 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1068 block.append(chunk, &language_registry, cx)
1069 }
1070 _ => {
1071 let block = ContentBlock::new(chunk, &language_registry, cx);
1072 if is_thought {
1073 chunks.push(AssistantMessageChunk::Thought { block })
1074 } else {
1075 chunks.push(AssistantMessageChunk::Message { block })
1076 }
1077 }
1078 }
1079 } else {
1080 let block = ContentBlock::new(chunk, &language_registry, cx);
1081 let chunk = if is_thought {
1082 AssistantMessageChunk::Thought { block }
1083 } else {
1084 AssistantMessageChunk::Message { block }
1085 };
1086
1087 self.push_entry(
1088 AgentThreadEntry::AssistantMessage(AssistantMessage {
1089 chunks: vec![chunk],
1090 }),
1091 cx,
1092 );
1093 }
1094 }
1095
1096 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1097 self.entries.push(entry);
1098 cx.emit(AcpThreadEvent::NewEntry);
1099 }
1100
1101 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1102 self.connection.set_title(&self.session_id, cx).is_some()
1103 }
1104
1105 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1106 if title != self.title {
1107 self.title = title.clone();
1108 cx.emit(AcpThreadEvent::TitleUpdated);
1109 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1110 return set_title.run(title, cx);
1111 }
1112 }
1113 Task::ready(Ok(()))
1114 }
1115
1116 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1117 self.token_usage = usage;
1118 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1119 }
1120
1121 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1122 cx.emit(AcpThreadEvent::Retry(status));
1123 }
1124
1125 pub fn update_tool_call(
1126 &mut self,
1127 update: impl Into<ToolCallUpdate>,
1128 cx: &mut Context<Self>,
1129 ) -> Result<()> {
1130 let update = update.into();
1131 let languages = self.project.read(cx).languages().clone();
1132
1133 let ix = self
1134 .index_for_tool_call(update.id())
1135 .context("Tool call not found")?;
1136 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1137 unreachable!()
1138 };
1139
1140 match update {
1141 ToolCallUpdate::UpdateFields(update) => {
1142 let location_updated = update.fields.locations.is_some();
1143 call.update_fields(update.fields, languages, &self.terminals, cx)?;
1144 if location_updated {
1145 self.resolve_locations(update.id, cx);
1146 }
1147 }
1148 ToolCallUpdate::UpdateDiff(update) => {
1149 call.content.clear();
1150 call.content.push(ToolCallContent::Diff(update.diff));
1151 }
1152 ToolCallUpdate::UpdateTerminal(update) => {
1153 call.content.clear();
1154 call.content
1155 .push(ToolCallContent::Terminal(update.terminal));
1156 }
1157 }
1158
1159 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1160
1161 Ok(())
1162 }
1163
1164 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1165 pub fn upsert_tool_call(
1166 &mut self,
1167 tool_call: acp::ToolCall,
1168 cx: &mut Context<Self>,
1169 ) -> Result<(), acp::Error> {
1170 let status = tool_call.status.into();
1171 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1172 }
1173
    /// Inserts or updates a tool call, depending on whether its id already exists.
    ///
    /// If a tool call with `update.id` exists, `update`'s fields are merged into it,
    /// its status is set to `status`, and [`AcpThreadEvent::EntryUpdated`] is
    /// emitted. Otherwise `update` is converted (via `TryInto`) into a full tool
    /// call, which is appended to the thread with `status`. On success, the tool
    /// call's locations are re-resolved in both cases.
    ///
    /// # Errors
    /// Fails if merging the fields into an existing call fails. When inserting,
    /// it also fails if `update` cannot be converted into a complete tool call or
    /// if building the new [`ToolCall`] fails.
    pub fn upsert_tool_call_inner(
        &mut self,
        update: acp::ToolCallUpdate,
        status: ToolCallStatus,
        cx: &mut Context<Self>,
    ) -> Result<(), acp::Error> {
        let language_registry = self.project.read(cx).languages().clone();
        let id = update.id.clone();

        if let Some(ix) = self.index_for_tool_call(&id) {
            let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
                unreachable!()
            };

            // The status is only replaced once the field merge has succeeded.
            call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
            call.status = status;

            cx.emit(AcpThreadEvent::EntryUpdated(ix));
        } else {
            let call = ToolCall::from_acp(
                update.try_into()?,
                status,
                language_registry,
                &self.terminals,
                cx,
            )?;
            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
        };

        self.resolve_locations(id, cx);
        Ok(())
    }
1207
1208 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1209 self.entries
1210 .iter()
1211 .enumerate()
1212 .rev()
1213 .find_map(|(index, entry)| {
1214 if let AgentThreadEntry::ToolCall(tool_call) = entry
1215 && &tool_call.id == id
1216 {
1217 Some(index)
1218 } else {
1219 None
1220 }
1221 })
1222 }
1223
1224 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1225 // The tool call we are looking for is typically the last one, or very close to the end.
1226 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1227 self.entries
1228 .iter_mut()
1229 .enumerate()
1230 .rev()
1231 .find_map(|(index, tool_call)| {
1232 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1233 && &tool_call.id == id
1234 {
1235 Some((index, tool_call))
1236 } else {
1237 None
1238 }
1239 })
1240 }
1241
1242 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1243 self.entries
1244 .iter()
1245 .enumerate()
1246 .rev()
1247 .find_map(|(index, tool_call)| {
1248 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1249 && &tool_call.id == id
1250 {
1251 Some((index, tool_call))
1252 } else {
1253 None
1254 }
1255 })
1256 }
1257
1258 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1259 let project = self.project.clone();
1260 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1261 return;
1262 };
1263 let task = tool_call.resolve_locations(project, cx);
1264 cx.spawn(async move |this, cx| {
1265 let resolved_locations = task.await;
1266 this.update(cx, |this, cx| {
1267 let project = this.project.clone();
1268 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1269 return;
1270 };
1271 if let Some(Some(location)) = resolved_locations.last() {
1272 project.update(cx, |project, cx| {
1273 if let Some(agent_location) = project.agent_location() {
1274 let should_ignore = agent_location.buffer == location.buffer
1275 && location
1276 .buffer
1277 .update(cx, |buffer, _| {
1278 let snapshot = buffer.snapshot();
1279 let old_position =
1280 agent_location.position.to_point(&snapshot);
1281 let new_position = location.position.to_point(&snapshot);
1282 // ignore this so that when we get updates from the edit tool
1283 // the position doesn't reset to the startof line
1284 old_position.row == new_position.row
1285 && old_position.column > new_position.column
1286 })
1287 .ok()
1288 .unwrap_or_default();
1289 if !should_ignore {
1290 project.set_agent_location(Some(location.clone()), cx);
1291 }
1292 }
1293 });
1294 }
1295 if tool_call.resolved_locations != resolved_locations {
1296 tool_call.resolved_locations = resolved_locations;
1297 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1298 }
1299 })
1300 })
1301 .detach();
1302 }
1303
1304 pub fn request_tool_call_authorization(
1305 &mut self,
1306 tool_call: acp::ToolCallUpdate,
1307 options: Vec<acp::PermissionOption>,
1308 cx: &mut Context<Self>,
1309 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1310 let (tx, rx) = oneshot::channel();
1311
1312 if AgentSettings::get_global(cx).always_allow_tool_actions {
1313 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1314 // some tools would (incorrectly) continue to auto-accept.
1315 if let Some(allow_once_option) = options.iter().find_map(|option| {
1316 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1317 Some(option.id.clone())
1318 } else {
1319 None
1320 }
1321 }) {
1322 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1323 return Ok(async {
1324 acp::RequestPermissionOutcome::Selected {
1325 option_id: allow_once_option,
1326 }
1327 }
1328 .boxed());
1329 }
1330 }
1331
1332 let status = ToolCallStatus::WaitingForConfirmation {
1333 options,
1334 respond_tx: tx,
1335 };
1336
1337 self.upsert_tool_call_inner(tool_call, status, cx)?;
1338 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1339
1340 let fut = async {
1341 match rx.await {
1342 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1343 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1344 }
1345 }
1346 .boxed();
1347
1348 Ok(fut)
1349 }
1350
1351 pub fn authorize_tool_call(
1352 &mut self,
1353 id: acp::ToolCallId,
1354 option_id: acp::PermissionOptionId,
1355 option_kind: acp::PermissionOptionKind,
1356 cx: &mut Context<Self>,
1357 ) {
1358 let Some((ix, call)) = self.tool_call_mut(&id) else {
1359 return;
1360 };
1361
1362 let new_status = match option_kind {
1363 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1364 ToolCallStatus::Rejected
1365 }
1366 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1367 ToolCallStatus::InProgress
1368 }
1369 };
1370
1371 let curr_status = mem::replace(&mut call.status, new_status);
1372
1373 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1374 respond_tx.send(option_id).log_err();
1375 } else if cfg!(debug_assertions) {
1376 panic!("tried to authorize an already authorized tool call");
1377 }
1378
1379 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1380 }
1381
1382 /// Returns true if the last turn is awaiting tool authorization
1383 pub fn waiting_for_tool_confirmation(&self) -> bool {
1384 for entry in self.entries.iter().rev() {
1385 match &entry {
1386 AgentThreadEntry::ToolCall(call) => match call.status {
1387 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1388 ToolCallStatus::Pending
1389 | ToolCallStatus::InProgress
1390 | ToolCallStatus::Completed
1391 | ToolCallStatus::Failed
1392 | ToolCallStatus::Rejected
1393 | ToolCallStatus::Canceled => continue,
1394 },
1395 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1396 // Reached the beginning of the turn
1397 return false;
1398 }
1399 }
1400 }
1401 false
1402 }
1403
    /// Returns the agent's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1407
1408 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1409 let new_entries_len = request.entries.len();
1410 let mut new_entries = request.entries.into_iter();
1411
1412 // Reuse existing markdown to prevent flickering
1413 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1414 let PlanEntry {
1415 content,
1416 priority,
1417 status,
1418 } = old;
1419 content.update(cx, |old, cx| {
1420 old.replace(new.content, cx);
1421 });
1422 *priority = new.priority;
1423 *status = new.status;
1424 }
1425 for new in new_entries {
1426 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1427 }
1428 self.plan.entries.truncate(new_entries_len);
1429
1430 cx.notify();
1431 }
1432
1433 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1434 self.plan
1435 .entries
1436 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1437 cx.notify();
1438 }
1439
1440 #[cfg(any(test, feature = "test-support"))]
1441 pub fn send_raw(
1442 &mut self,
1443 message: &str,
1444 cx: &mut Context<Self>,
1445 ) -> BoxFuture<'static, Result<()>> {
1446 self.send(
1447 vec![acp::ContentBlock::Text(acp::TextContent {
1448 text: message.to_string(),
1449 annotations: None,
1450 })],
1451 cx,
1452 )
1453 }
1454
    /// Sends `message` to the agent as a new user turn.
    ///
    /// Any in-progress turn is canceled first (via `run_turn`). The steps are:
    /// - The message is appended as a [`UserMessage`] entry.
    /// - A git checkpoint of the project is taken and attached, initially hidden,
    ///   to the most recent user message (normally the one just pushed).
    /// - The prompt is forwarded to the connection.
    ///
    /// The message only gets a [`UserMessageId`] when the connection supports
    /// truncation, because that id is how `rewind` finds the message later.
    ///
    /// The returned future resolves once the turn has finished.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        // Only assign an id when the connection can truncate; `rewind` looks the
        // message up by this id.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // Snapshot the repository before the agent runs. If taking the
            // checkpoint fails, the error is logged and the turn proceeds without
            // one.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1508
    /// Returns whether the connection supports resuming this thread's session.
    pub fn can_resume(&self, cx: &App) -> bool {
        self.connection.resume(&self.session_id, cx).is_some()
    }
1512
    /// Runs a new turn that resumes this thread's session through the
    /// connection's resume support.
    ///
    /// The returned future fails with "resuming a session is not supported" when
    /// the connection cannot resume sessions.
    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.connection
                    .resume(&this.session_id, cx)
                    .map(|resume| resume.run(cx))
            })?
            .context("resuming a session is not supported")?
            .await
        })
    }
1524
    /// Runs one agent turn driven by `f` and returns a future that resolves once
    /// the turn has been fully processed.
    ///
    /// Completed plan entries are cleared and any turn already in flight is
    /// canceled. `f` only starts after that cancellation has finished. `f` runs
    /// inside `send_task`, and its result reaches the returned future through a
    /// oneshot channel. When the result arrives:
    /// - The last user message's checkpoint visibility is refreshed. If that
    ///   fails, its error is returned.
    /// - The agent location is cleared.
    /// - If `f` returned an error, [`AcpThreadEvent::Error`] is emitted and the
    ///   error is returned.
    /// - Otherwise [`AcpThreadEvent::Stopped`] is emitted. If the agent refused
    ///   the prompt, the last user message and everything after it are removed
    ///   first.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        // `f`'s result is forwarded through this channel so the returned future
        // can observe it even though `f` itself runs inside `send_task`.
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1591
1592 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1593 let Some(send_task) = self.send_task.take() else {
1594 return Task::ready(());
1595 };
1596
1597 for entry in self.entries.iter_mut() {
1598 if let AgentThreadEntry::ToolCall(call) = entry {
1599 let cancel = matches!(
1600 call.status,
1601 ToolCallStatus::Pending
1602 | ToolCallStatus::WaitingForConfirmation { .. }
1603 | ToolCallStatus::InProgress
1604 );
1605
1606 if cancel {
1607 call.status = ToolCallStatus::Canceled;
1608 }
1609 }
1610 }
1611
1612 self.connection.cancel(&self.session_id, cx);
1613
1614 // Wait for the send task to complete
1615 cx.foreground_executor().spawn(send_task)
1616 }
1617
    /// Rewinds this thread to just before the user message identified by `id`,
    /// removing that message and all subsequent entries.
    ///
    /// The steps run in this order:
    /// 1. If the message has a git checkpoint, the project is restored to it.
    /// 2. The connection truncates its session at `id`.
    /// 3. The local entries are removed and [`AcpThreadEvent::EntriesRemoved`]
    ///    is emitted.
    ///
    /// # Errors
    /// Fails with "not supported" if the connection cannot truncate, and with
    /// "message not found" if no user message has this id. It also fails if
    /// restoring the checkpoint or truncating the session fails.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };
        let Some(message) = self.user_message(&id) else {
            return Task::ready(Err(anyhow!("message not found")));
        };

        let checkpoint = message
            .checkpoint
            .as_ref()
            .map(|c| c.git_checkpoint.clone());

        let git_store = self.project.read(cx).git_store().clone();
        cx.spawn(async move |this, cx| {
            if let Some(checkpoint) = checkpoint {
                git_store
                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
                    .await?;
            }

            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
            // The message is looked up again here because entries may have
            // changed while the restore/truncate were awaited.
            this.update(cx, |this, cx| {
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
                }
            })
        })
    }
1651
    /// Decides whether the most recent user message's checkpoint should be shown.
    ///
    /// Takes a new git checkpoint and compares it with the one stored on the last
    /// user message. `show` is set to `true` when they differ, and
    /// [`AcpThreadEvent::EntryUpdated`] is emitted for that message. Edge cases:
    /// - There is no user message, or it has no checkpoint: nothing happens.
    /// - The new checkpoint cannot be taken: the failure is logged and the
    ///   message is left untouched.
    /// - The comparison itself fails: it is treated as "equal", so `show` becomes
    ///   `false`.
    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let git_store = self.project.read(cx).git_store().clone();

        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
            if let Some(checkpoint) = message.checkpoint.as_ref() {
                checkpoint.git_checkpoint.clone()
            } else {
                return Task::ready(Ok(()));
            }
        } else {
            return Task::ready(Ok(()));
        };

        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
        cx.spawn(async move |this, cx| {
            let new_checkpoint = new_checkpoint
                .await
                .context("failed to get new checkpoint")
                .log_err();
            if let Some(new_checkpoint) = new_checkpoint {
                let equal = git_store
                    .update(cx, |git, cx| {
                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
                    })?
                    .await
                    .unwrap_or(true);
                this.update(cx, |this, cx| {
                    let (ix, message) = this.last_user_message().context("no user message")?;
                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
                    checkpoint.show = !equal;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                    anyhow::Ok(())
                })??;
            }

            Ok(())
        })
    }
1690
1691 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1692 self.entries
1693 .iter_mut()
1694 .enumerate()
1695 .rev()
1696 .find_map(|(ix, entry)| {
1697 if let AgentThreadEntry::UserMessage(message) = entry {
1698 Some((ix, message))
1699 } else {
1700 None
1701 }
1702 })
1703 }
1704
1705 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1706 self.entries.iter().find_map(|entry| {
1707 if let AgentThreadEntry::UserMessage(message) = entry {
1708 if message.id.as_ref() == Some(id) {
1709 Some(message)
1710 } else {
1711 None
1712 }
1713 } else {
1714 None
1715 }
1716 })
1717 }
1718
1719 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1720 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1721 if let AgentThreadEntry::UserMessage(message) = entry {
1722 if message.id.as_ref() == Some(id) {
1723 Some((ix, message))
1724 } else {
1725 None
1726 }
1727 } else {
1728 None
1729 }
1730 })
1731 }
1732
1733 pub fn read_text_file(
1734 &self,
1735 path: PathBuf,
1736 line: Option<u32>,
1737 limit: Option<u32>,
1738 reuse_shared_snapshot: bool,
1739 cx: &mut Context<Self>,
1740 ) -> Task<Result<String>> {
1741 let project = self.project.clone();
1742 let action_log = self.action_log.clone();
1743 cx.spawn(async move |this, cx| {
1744 let load = project.update(cx, |project, cx| {
1745 let path = project
1746 .project_path_for_absolute_path(&path, cx)
1747 .context("invalid path")?;
1748 anyhow::Ok(project.open_buffer(path, cx))
1749 });
1750 let buffer = load??.await?;
1751
1752 let snapshot = if reuse_shared_snapshot {
1753 this.read_with(cx, |this, _| {
1754 this.shared_buffers.get(&buffer.clone()).cloned()
1755 })
1756 .log_err()
1757 .flatten()
1758 } else {
1759 None
1760 };
1761
1762 let snapshot = if let Some(snapshot) = snapshot {
1763 snapshot
1764 } else {
1765 action_log.update(cx, |action_log, cx| {
1766 action_log.buffer_read(buffer.clone(), cx);
1767 })?;
1768 project.update(cx, |project, cx| {
1769 let position = buffer
1770 .read(cx)
1771 .snapshot()
1772 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1773 project.set_agent_location(
1774 Some(AgentLocation {
1775 buffer: buffer.downgrade(),
1776 position,
1777 }),
1778 cx,
1779 );
1780 })?;
1781
1782 buffer.update(cx, |buffer, _| buffer.snapshot())?
1783 };
1784
1785 this.update(cx, |this, _| {
1786 let text = snapshot.text();
1787 this.shared_buffers.insert(buffer.clone(), snapshot);
1788 if line.is_none() && limit.is_none() {
1789 return Ok(text);
1790 }
1791 let limit = limit.unwrap_or(u32::MAX) as usize;
1792 let Some(line) = line else {
1793 return Ok(text.lines().take(limit).collect::<String>());
1794 };
1795
1796 let count = text.lines().count();
1797 if count < line as usize {
1798 anyhow::bail!("There are only {} lines", count);
1799 }
1800 Ok(text
1801 .lines()
1802 .skip(line as usize + 1)
1803 .take(limit)
1804 .collect::<String>())
1805 })?
1806 })
1807 }
1808
    /// Replaces the contents of the file at `path` with `content` on behalf of the
    /// agent, then saves it.
    ///
    /// How the write is applied:
    /// - The new content is diffed against the snapshot last shared with the
    ///   agent for this buffer, or the buffer's current snapshot if none was
    ///   shared.
    /// - The resulting edits are anchored to that snapshot before being applied,
    ///   so text the user inserted in the meantime is kept.
    /// - The agent location moves to the end of the last edit, and the edit is
    ///   recorded in the action log.
    /// - If the buffer's language settings enable format-on-save, the buffer is
    ///   formatted before saving. Formatting errors are logged and ignored.
    ///
    /// # Errors
    /// Fails in these cases:
    /// - the path is not inside the project;
    /// - the buffer cannot be opened;
    /// - saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diff on the background executor, then express each changed range as
            // anchors into `snapshot` so the edits still land correctly if the
            // buffer has changed since.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // NOTE(review): the read is recorded right before the edit,
                // presumably so the action log has a baseline for the agent's
                // change — confirm against ActionLog.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                format_task.await.log_err();

                // Formatting may have changed the buffer again; record that too.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1905
1906 pub fn create_terminal(
1907 &self,
1908 mut command: String,
1909 args: Vec<String>,
1910 extra_env: Vec<acp::EnvVariable>,
1911 cwd: Option<PathBuf>,
1912 output_byte_limit: Option<u64>,
1913 cx: &mut Context<Self>,
1914 ) -> Task<Result<Entity<Terminal>>> {
1915 for arg in args {
1916 command.push(' ');
1917 command.push_str(&arg);
1918 }
1919
1920 let shell_command = if cfg!(windows) {
1921 format!("$null | & {{{}}}", command.replace("\"", "'"))
1922 } else if let Some(cwd) = cwd.as_ref().and_then(|cwd| cwd.as_os_str().to_str()) {
1923 // Make sure once we're *inside* the shell, we cd into `cwd`
1924 format!("(cd {cwd}; {}) </dev/null", command)
1925 } else {
1926 format!("({}) </dev/null", command)
1927 };
1928 let args = vec!["-c".into(), shell_command];
1929
1930 let env = match &cwd {
1931 Some(dir) => self.project.update(cx, |project, cx| {
1932 project.directory_environment(dir.as_path().into(), cx)
1933 }),
1934 None => Task::ready(None).shared(),
1935 };
1936
1937 let env = cx.spawn(async move |_, _| {
1938 let mut env = env.await.unwrap_or_default();
1939 if cfg!(unix) {
1940 env.insert("PAGER".into(), "cat".into());
1941 }
1942 for var in extra_env {
1943 env.insert(var.name, var.value);
1944 }
1945 env
1946 });
1947
1948 let project = self.project.clone();
1949 let language_registry = project.read(cx).languages().clone();
1950 let determine_shell = self.determine_shell.clone();
1951
1952 let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
1953 let terminal_task = cx.spawn({
1954 let terminal_id = terminal_id.clone();
1955 async move |_this, cx| {
1956 let program = determine_shell.await;
1957 let env = env.await;
1958 let terminal = project
1959 .update(cx, |project, cx| {
1960 project.create_terminal_task(
1961 task::SpawnInTerminal {
1962 command: Some(program),
1963 args,
1964 cwd: cwd.clone(),
1965 env,
1966 ..Default::default()
1967 },
1968 cx,
1969 )
1970 })?
1971 .await?;
1972
1973 cx.new(|cx| {
1974 Terminal::new(
1975 terminal_id,
1976 command,
1977 cwd,
1978 output_byte_limit.map(|l| l as usize),
1979 terminal,
1980 language_registry,
1981 cx,
1982 )
1983 })
1984 }
1985 });
1986
1987 cx.spawn(async move |this, cx| {
1988 let terminal = terminal_task.await?;
1989 this.update(cx, |this, _cx| {
1990 this.terminals.insert(terminal_id, terminal.clone());
1991 terminal
1992 })
1993 })
1994 }
1995
1996 pub fn kill_terminal(
1997 &mut self,
1998 terminal_id: acp::TerminalId,
1999 cx: &mut Context<Self>,
2000 ) -> Result<()> {
2001 self.terminals
2002 .get(&terminal_id)
2003 .context("Terminal not found")?
2004 .update(cx, |terminal, cx| {
2005 terminal.kill(cx);
2006 });
2007
2008 Ok(())
2009 }
2010
2011 pub fn release_terminal(
2012 &mut self,
2013 terminal_id: acp::TerminalId,
2014 cx: &mut Context<Self>,
2015 ) -> Result<()> {
2016 self.terminals
2017 .remove(&terminal_id)
2018 .context("Terminal not found")?
2019 .update(cx, |terminal, cx| {
2020 terminal.kill(cx);
2021 });
2022
2023 Ok(())
2024 }
2025
2026 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2027 self.terminals
2028 .get(&terminal_id)
2029 .context("Terminal not found")
2030 .cloned()
2031 }
2032
2033 pub fn to_markdown(&self, cx: &App) -> String {
2034 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2035 }
2036
    /// Emits [`AcpThreadEvent::LoadError`] carrying `error` to subscribers.
    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::LoadError(error));
    }
2040}
2041
2042fn markdown_for_raw_output(
2043 raw_output: &serde_json::Value,
2044 language_registry: &Arc<LanguageRegistry>,
2045 cx: &mut App,
2046) -> Option<Entity<Markdown>> {
2047 match raw_output {
2048 serde_json::Value::Null => None,
2049 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2050 Markdown::new(
2051 value.to_string().into(),
2052 Some(language_registry.clone()),
2053 None,
2054 cx,
2055 )
2056 })),
2057 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2058 Markdown::new(
2059 value.to_string().into(),
2060 Some(language_registry.clone()),
2061 None,
2062 cx,
2063 )
2064 })),
2065 serde_json::Value::String(value) => Some(cx.new(|cx| {
2066 Markdown::new(
2067 value.clone().into(),
2068 Some(language_registry.clone()),
2069 None,
2070 cx,
2071 )
2072 })),
2073 value => Some(cx.new(|cx| {
2074 Markdown::new(
2075 format!("```json\n{}\n```", value).into(),
2076 Some(language_registry.clone()),
2077 None,
2078 cx,
2079 )
2080 })),
2081 }
2082}
2083
2084#[cfg(test)]
2085mod tests {
2086 use super::*;
2087 use anyhow::anyhow;
2088 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2089 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2090 use indoc::indoc;
2091 use project::{FakeFs, Fs};
2092 use rand::Rng as _;
2093 use serde_json::json;
2094 use settings::SettingsStore;
2095 use smol::stream::StreamExt as _;
2096 use std::{
2097 any::Any,
2098 cell::RefCell,
2099 path::Path,
2100 rc::Rc,
2101 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2102 time::Duration,
2103 };
2104 use util::path;
2105
2106 fn init_test(cx: &mut TestAppContext) {
2107 env_logger::try_init().ok();
2108 cx.update(|cx| {
2109 let settings_store = SettingsStore::test(cx);
2110 cx.set_global(settings_store);
2111 Project::init_settings(cx);
2112 language::init(cx);
2113 });
2114 }
2115
2116 #[gpui::test]
2117 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2118 init_test(cx);
2119
2120 let fs = FakeFs::new(cx.executor());
2121 let project = Project::test(fs, [], cx).await;
2122 let connection = Rc::new(FakeAgentConnection::new());
2123 let thread = cx
2124 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2125 .await
2126 .unwrap();
2127
2128 // Test creating a new user message
2129 thread.update(cx, |thread, cx| {
2130 thread.push_user_content_block(
2131 None,
2132 acp::ContentBlock::Text(acp::TextContent {
2133 annotations: None,
2134 text: "Hello, ".to_string(),
2135 }),
2136 cx,
2137 );
2138 });
2139
2140 thread.update(cx, |thread, cx| {
2141 assert_eq!(thread.entries.len(), 1);
2142 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2143 assert_eq!(user_msg.id, None);
2144 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2145 } else {
2146 panic!("Expected UserMessage");
2147 }
2148 });
2149
2150 // Test appending to existing user message
2151 let message_1_id = UserMessageId::new();
2152 thread.update(cx, |thread, cx| {
2153 thread.push_user_content_block(
2154 Some(message_1_id.clone()),
2155 acp::ContentBlock::Text(acp::TextContent {
2156 annotations: None,
2157 text: "world!".to_string(),
2158 }),
2159 cx,
2160 );
2161 });
2162
2163 thread.update(cx, |thread, cx| {
2164 assert_eq!(thread.entries.len(), 1);
2165 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2166 assert_eq!(user_msg.id, Some(message_1_id));
2167 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2168 } else {
2169 panic!("Expected UserMessage");
2170 }
2171 });
2172
2173 // Test creating new user message after assistant message
2174 thread.update(cx, |thread, cx| {
2175 thread.push_assistant_content_block(
2176 acp::ContentBlock::Text(acp::TextContent {
2177 annotations: None,
2178 text: "Assistant response".to_string(),
2179 }),
2180 false,
2181 cx,
2182 );
2183 });
2184
2185 let message_2_id = UserMessageId::new();
2186 thread.update(cx, |thread, cx| {
2187 thread.push_user_content_block(
2188 Some(message_2_id.clone()),
2189 acp::ContentBlock::Text(acp::TextContent {
2190 annotations: None,
2191 text: "New user message".to_string(),
2192 }),
2193 cx,
2194 );
2195 });
2196
2197 thread.update(cx, |thread, cx| {
2198 assert_eq!(thread.entries.len(), 3);
2199 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2200 assert_eq!(user_msg.id, Some(message_2_id));
2201 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2202 } else {
2203 panic!("Expected UserMessage at index 2");
2204 }
2205 });
2206 }
2207
2208 #[gpui::test]
2209 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2210 init_test(cx);
2211
2212 let fs = FakeFs::new(cx.executor());
2213 let project = Project::test(fs, [], cx).await;
2214 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2215 |_, thread, mut cx| {
2216 async move {
2217 thread.update(&mut cx, |thread, cx| {
2218 thread
2219 .handle_session_update(
2220 acp::SessionUpdate::AgentThoughtChunk {
2221 content: "Thinking ".into(),
2222 },
2223 cx,
2224 )
2225 .unwrap();
2226 thread
2227 .handle_session_update(
2228 acp::SessionUpdate::AgentThoughtChunk {
2229 content: "hard!".into(),
2230 },
2231 cx,
2232 )
2233 .unwrap();
2234 })?;
2235 Ok(acp::PromptResponse {
2236 stop_reason: acp::StopReason::EndTurn,
2237 })
2238 }
2239 .boxed_local()
2240 },
2241 ));
2242
2243 let thread = cx
2244 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2245 .await
2246 .unwrap();
2247
2248 thread
2249 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2250 .await
2251 .unwrap();
2252
2253 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2254 assert_eq!(
2255 output,
2256 indoc! {r#"
2257 ## User
2258
2259 Hello from Zed!
2260
2261 ## Assistant
2262
2263 <thinking>
2264 Thinking hard!
2265 </thinking>
2266
2267 "#}
2268 );
2269 }
2270
2271 #[gpui::test]
2272 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2273 init_test(cx);
2274
2275 let fs = FakeFs::new(cx.executor());
2276 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2277 .await;
2278 let project = Project::test(fs.clone(), [], cx).await;
2279 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2280 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2281 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2282 move |_, thread, mut cx| {
2283 let read_file_tx = read_file_tx.clone();
2284 async move {
2285 let content = thread
2286 .update(&mut cx, |thread, cx| {
2287 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2288 })
2289 .unwrap()
2290 .await
2291 .unwrap();
2292 assert_eq!(content, "one\ntwo\nthree\n");
2293 read_file_tx.take().unwrap().send(()).unwrap();
2294 thread
2295 .update(&mut cx, |thread, cx| {
2296 thread.write_text_file(
2297 path!("/tmp/foo").into(),
2298 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2299 cx,
2300 )
2301 })
2302 .unwrap()
2303 .await
2304 .unwrap();
2305 Ok(acp::PromptResponse {
2306 stop_reason: acp::StopReason::EndTurn,
2307 })
2308 }
2309 .boxed_local()
2310 },
2311 ));
2312
2313 let (worktree, pathbuf) = project
2314 .update(cx, |project, cx| {
2315 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2316 })
2317 .await
2318 .unwrap();
2319 let buffer = project
2320 .update(cx, |project, cx| {
2321 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2322 })
2323 .await
2324 .unwrap();
2325
2326 let thread = cx
2327 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2328 .await
2329 .unwrap();
2330
2331 let request = thread.update(cx, |thread, cx| {
2332 thread.send_raw("Extend the count in /tmp/foo", cx)
2333 });
2334 read_file_rx.await.ok();
2335 buffer.update(cx, |buffer, cx| {
2336 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2337 });
2338 cx.run_until_parked();
2339 assert_eq!(
2340 buffer.read_with(cx, |buffer, _| buffer.text()),
2341 "zero\none\ntwo\nthree\nfour\nfive\n"
2342 );
2343 assert_eq!(
2344 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2345 "zero\none\ntwo\nthree\nfour\nfive\n"
2346 );
2347 request.await.unwrap();
2348 }
2349
2350 #[gpui::test]
2351 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2352 init_test(cx);
2353
2354 let fs = FakeFs::new(cx.executor());
2355 let project = Project::test(fs, [], cx).await;
2356 let id = acp::ToolCallId("test".into());
2357
2358 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2359 let id = id.clone();
2360 move |_, thread, mut cx| {
2361 let id = id.clone();
2362 async move {
2363 thread
2364 .update(&mut cx, |thread, cx| {
2365 thread.handle_session_update(
2366 acp::SessionUpdate::ToolCall(acp::ToolCall {
2367 id: id.clone(),
2368 title: "Label".into(),
2369 kind: acp::ToolKind::Fetch,
2370 status: acp::ToolCallStatus::InProgress,
2371 content: vec![],
2372 locations: vec![],
2373 raw_input: None,
2374 raw_output: None,
2375 }),
2376 cx,
2377 )
2378 })
2379 .unwrap()
2380 .unwrap();
2381 Ok(acp::PromptResponse {
2382 stop_reason: acp::StopReason::EndTurn,
2383 })
2384 }
2385 .boxed_local()
2386 }
2387 }));
2388
2389 let thread = cx
2390 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2391 .await
2392 .unwrap();
2393
2394 let request = thread.update(cx, |thread, cx| {
2395 thread.send_raw("Fetch https://example.com", cx)
2396 });
2397
2398 run_until_first_tool_call(&thread, cx).await;
2399
2400 thread.read_with(cx, |thread, _| {
2401 assert!(matches!(
2402 thread.entries[1],
2403 AgentThreadEntry::ToolCall(ToolCall {
2404 status: ToolCallStatus::InProgress,
2405 ..
2406 })
2407 ));
2408 });
2409
2410 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2411
2412 thread.read_with(cx, |thread, _| {
2413 assert!(matches!(
2414 &thread.entries[1],
2415 AgentThreadEntry::ToolCall(ToolCall {
2416 status: ToolCallStatus::Canceled,
2417 ..
2418 })
2419 ));
2420 });
2421
2422 thread
2423 .update(cx, |thread, cx| {
2424 thread.handle_session_update(
2425 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2426 id,
2427 fields: acp::ToolCallUpdateFields {
2428 status: Some(acp::ToolCallStatus::Completed),
2429 ..Default::default()
2430 },
2431 }),
2432 cx,
2433 )
2434 })
2435 .unwrap();
2436
2437 request.await.unwrap();
2438
2439 thread.read_with(cx, |thread, _| {
2440 assert!(matches!(
2441 thread.entries[1],
2442 AgentThreadEntry::ToolCall(ToolCall {
2443 status: ToolCallStatus::Completed,
2444 ..
2445 })
2446 ));
2447 });
2448 }
2449
2450 #[gpui::test]
2451 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2452 init_test(cx);
2453 let fs = FakeFs::new(cx.background_executor.clone());
2454 fs.insert_tree(path!("/test"), json!({})).await;
2455 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2456
2457 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2458 move |_, thread, mut cx| {
2459 async move {
2460 thread
2461 .update(&mut cx, |thread, cx| {
2462 thread.handle_session_update(
2463 acp::SessionUpdate::ToolCall(acp::ToolCall {
2464 id: acp::ToolCallId("test".into()),
2465 title: "Label".into(),
2466 kind: acp::ToolKind::Edit,
2467 status: acp::ToolCallStatus::Completed,
2468 content: vec![acp::ToolCallContent::Diff {
2469 diff: acp::Diff {
2470 path: "/test/test.txt".into(),
2471 old_text: None,
2472 new_text: "foo".into(),
2473 },
2474 }],
2475 locations: vec![],
2476 raw_input: None,
2477 raw_output: None,
2478 }),
2479 cx,
2480 )
2481 })
2482 .unwrap()
2483 .unwrap();
2484 Ok(acp::PromptResponse {
2485 stop_reason: acp::StopReason::EndTurn,
2486 })
2487 }
2488 .boxed_local()
2489 }
2490 }));
2491
2492 let thread = cx
2493 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2494 .await
2495 .unwrap();
2496
2497 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2498 .await
2499 .unwrap();
2500
2501 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2502 }
2503
    // Exercises git checkpoints on user messages:
    // - A message keeps a visible checkpoint (rendered as "## User (checkpoint)") only when the
    //   agent's turn changed files in the worktree.
    // - Rewinding to a message truncates the conversation from that message onward and
    //   restores the worktree files.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // The fixture contains an empty `.git` directory. Presumably this makes the project a
        // git repository so checkpoints can be taken; confirm.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is true, every prompt writes a new empty file
        // `/test/file-N`, where N is taken from (and advances) `next_filename`.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                // The handler is `Fn`, so each invocation clones its own handles for the future.
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        // NOTE(review): this path is not wrapped in `path!`, unlike the
                        // `fs.files()` assertions below. Confirm this is intended on Windows.
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Reply with the prompt text uppercased, as a single agent message chunk.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First turn writes file-0, so the "Lorem" message shows its checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second turn writes file-1, so the "ipsum" message shows its checkpoint too.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // entries[2] is expected to be the second user message ("ipsum"). After the rewind,
        // only the first exchange remains and file-1 is gone from the worktree.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2683
2684 #[gpui::test]
2685 async fn test_refusal(cx: &mut TestAppContext) {
2686 init_test(cx);
2687 let fs = FakeFs::new(cx.background_executor.clone());
2688 fs.insert_tree(path!("/"), json!({})).await;
2689 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2690
2691 let refuse_next = Arc::new(AtomicBool::new(false));
2692 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2693 let refuse_next = refuse_next.clone();
2694 move |request, thread, mut cx| {
2695 let refuse_next = refuse_next.clone();
2696 async move {
2697 if refuse_next.load(SeqCst) {
2698 return Ok(acp::PromptResponse {
2699 stop_reason: acp::StopReason::Refusal,
2700 });
2701 }
2702
2703 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2704 panic!("expected text content block");
2705 };
2706 thread.update(&mut cx, |thread, cx| {
2707 thread
2708 .handle_session_update(
2709 acp::SessionUpdate::AgentMessageChunk {
2710 content: content.text.to_uppercase().into(),
2711 },
2712 cx,
2713 )
2714 .unwrap();
2715 })?;
2716 Ok(acp::PromptResponse {
2717 stop_reason: acp::StopReason::EndTurn,
2718 })
2719 }
2720 .boxed_local()
2721 }
2722 }));
2723 let thread = cx
2724 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2725 .await
2726 .unwrap();
2727
2728 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2729 .await
2730 .unwrap();
2731 thread.read_with(cx, |thread, cx| {
2732 assert_eq!(
2733 thread.to_markdown(cx),
2734 indoc! {"
2735 ## User
2736
2737 hello
2738
2739 ## Assistant
2740
2741 HELLO
2742
2743 "}
2744 );
2745 });
2746
2747 // Simulate refusing the second message, ensuring the conversation gets
2748 // truncated to before sending it.
2749 refuse_next.store(true, SeqCst);
2750 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2751 .await
2752 .unwrap();
2753 thread.read_with(cx, |thread, cx| {
2754 assert_eq!(
2755 thread.to_markdown(cx),
2756 indoc! {"
2757 ## User
2758
2759 hello
2760
2761 ## Assistant
2762
2763 HELLO
2764
2765 "}
2766 );
2767 });
2768 }
2769
2770 async fn run_until_first_tool_call(
2771 thread: &Entity<AcpThread>,
2772 cx: &mut TestAppContext,
2773 ) -> usize {
2774 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2775
2776 let subscription = cx.update(|cx| {
2777 cx.subscribe(thread, move |thread, _, cx| {
2778 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2779 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2780 return tx.try_send(ix).unwrap();
2781 }
2782 }
2783 })
2784 });
2785
2786 select! {
2787 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2788 panic!("Timeout waiting for tool call")
2789 }
2790 ix = rx.next().fuse() => {
2791 drop(subscription);
2792 ix.unwrap()
2793 }
2794 }
2795 }
2796
2797 #[derive(Clone, Default)]
2798 struct FakeAgentConnection {
2799 auth_methods: Vec<acp::AuthMethod>,
2800 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2801 on_user_message: Option<
2802 Rc<
2803 dyn Fn(
2804 acp::PromptRequest,
2805 WeakEntity<AcpThread>,
2806 AsyncApp,
2807 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2808 + 'static,
2809 >,
2810 >,
2811 }
2812
2813 impl FakeAgentConnection {
2814 fn new() -> Self {
2815 Self {
2816 auth_methods: Vec::new(),
2817 on_user_message: None,
2818 sessions: Arc::default(),
2819 }
2820 }
2821
2822 #[expect(unused)]
2823 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2824 self.auth_methods = auth_methods;
2825 self
2826 }
2827
2828 fn on_user_message(
2829 mut self,
2830 handler: impl Fn(
2831 acp::PromptRequest,
2832 WeakEntity<AcpThread>,
2833 AsyncApp,
2834 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2835 + 'static,
2836 ) -> Self {
2837 self.on_user_message.replace(Rc::new(handler));
2838 self
2839 }
2840 }
2841
2842 impl AgentConnection for FakeAgentConnection {
2843 fn auth_methods(&self) -> &[acp::AuthMethod] {
2844 &self.auth_methods
2845 }
2846
2847 fn new_thread(
2848 self: Rc<Self>,
2849 project: Entity<Project>,
2850 _cwd: &Path,
2851 cx: &mut App,
2852 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2853 let session_id = acp::SessionId(
2854 rand::thread_rng()
2855 .sample_iter(&rand::distributions::Alphanumeric)
2856 .take(7)
2857 .map(char::from)
2858 .collect::<String>()
2859 .into(),
2860 );
2861 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2862 let thread = cx.new(|cx| {
2863 AcpThread::new(
2864 "Test",
2865 self.clone(),
2866 project,
2867 action_log,
2868 session_id.clone(),
2869 watch::Receiver::constant(acp::PromptCapabilities {
2870 image: true,
2871 audio: true,
2872 embedded_context: true,
2873 }),
2874 vec![],
2875 cx,
2876 )
2877 });
2878 self.sessions.lock().insert(session_id, thread.downgrade());
2879 Task::ready(Ok(thread))
2880 }
2881
2882 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2883 if self.auth_methods().iter().any(|m| m.id == method) {
2884 Task::ready(Ok(()))
2885 } else {
2886 Task::ready(Err(anyhow!("Invalid Auth Method")))
2887 }
2888 }
2889
2890 fn prompt(
2891 &self,
2892 _id: Option<UserMessageId>,
2893 params: acp::PromptRequest,
2894 cx: &mut App,
2895 ) -> Task<gpui::Result<acp::PromptResponse>> {
2896 let sessions = self.sessions.lock();
2897 let thread = sessions.get(¶ms.session_id).unwrap();
2898 if let Some(handler) = &self.on_user_message {
2899 let handler = handler.clone();
2900 let thread = thread.clone();
2901 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2902 } else {
2903 Task::ready(Ok(acp::PromptResponse {
2904 stop_reason: acp::StopReason::EndTurn,
2905 }))
2906 }
2907 }
2908
2909 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2910 let sessions = self.sessions.lock();
2911 let thread = sessions.get(session_id).unwrap().clone();
2912
2913 cx.spawn(async move |cx| {
2914 thread
2915 .update(cx, |thread, cx| thread.cancel(cx))
2916 .unwrap()
2917 .await
2918 })
2919 .detach();
2920 }
2921
2922 fn truncate(
2923 &self,
2924 session_id: &acp::SessionId,
2925 _cx: &App,
2926 ) -> Option<Rc<dyn AgentSessionTruncate>> {
2927 Some(Rc::new(FakeAgentSessionEditor {
2928 _session_id: session_id.clone(),
2929 }))
2930 }
2931
2932 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2933 self
2934 }
2935 }
2936
2937 struct FakeAgentSessionEditor {
2938 _session_id: acp::SessionId,
2939 }
2940
2941 impl AgentSessionTruncate for FakeAgentSessionEditor {
2942 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2943 Task::ready(Ok(()))
2944 }
2945 }
2946}