1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9};
10use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
11use anyhow::{Result, anyhow};
12use assistant_tool::{
13 ActionLog, AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus,
14 ToolWorkingSet,
15};
16use chrono::{DateTime, Utc};
17use client::{ModelRequestUsage, RequestUsage};
18use collections::{HashMap, HashSet};
19use feature_flags::{self, FeatureFlagAppExt};
20use futures::{
21 FutureExt, StreamExt as _,
22 channel::{mpsc, oneshot},
23 future::{BoxFuture, Either, LocalBoxFuture, Shared},
24 stream::{BoxStream, LocalBoxStream},
25};
26use git::repository::DiffType;
27use gpui::{
28 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
29 WeakEntity, Window,
30};
31use icons::IconName;
32use language_model::{
33 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
34 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
35 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
36 LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
37 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
38 TokenUsage,
39};
40use postage::stream::Stream as _;
41use project::{
42 Project,
43 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
44};
45use prompt_store::{ModelContext, PromptBuilder};
46use proto::Plan;
47use schemars::JsonSchema;
48use serde::{Deserialize, Serialize};
49use settings::Settings;
50use std::{collections::VecDeque, fmt::Write};
51use std::{
52 ops::Range,
53 sync::Arc,
54 time::{Duration, Instant},
55};
56use thiserror::Error;
57use util::{ResultExt as _, debug_panic, post_inc, truncate_lines_to_byte_limit};
58use uuid::Uuid;
59use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
60
/// Maximum number of automatic retries for a failed completion request.
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay used when scheduling a retry of a failed completion request.
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
63
64#[derive(
65 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
66)]
67pub struct ThreadId(Arc<str>);
68
69impl ThreadId {
70 pub fn new() -> Self {
71 Self(Uuid::new_v4().to_string().into())
72 }
73}
74
75impl std::fmt::Display for ThreadId {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 write!(f, "{}", self.0)
78 }
79}
80
81impl From<&str> for ThreadId {
82 fn from(value: &str) -> Self {
83 Self(value.into())
84 }
85}
86
87/// The ID of the user prompt that initiated a request.
88///
89/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
90#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
91pub struct PromptId(Arc<str>);
92
93impl PromptId {
94 pub fn new() -> Self {
95 Self(Uuid::new_v4().to_string().into())
96 }
97}
98
99impl std::fmt::Display for PromptId {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 write!(f, "{}", self.0)
102 }
103}
104
/// UI-facing description of a tool invocation within a thread.
#[derive(Debug)]
pub struct ToolUse {
    pub id: LanguageModelToolUseId,
    /// Tool name as reported by the model.
    pub name: SharedString,
    /// Display text for this tool use in the UI.
    pub ui_text: SharedString,
    pub status: ToolUseStatus,
    /// Raw JSON input the model supplied for the tool.
    pub input: serde_json::Value,
    pub icon: icons::IconName,
    /// Whether the user must approve this tool use before it runs.
    pub needs_confirmation: bool,
}
115
/// A tool use that has been requested by the model but not yet completed.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    pub name: Arc<str>,
    /// Display text for this tool use in the UI.
    pub ui_text: Arc<str>,
    /// Raw JSON input the model supplied for the tool.
    pub input: serde_json::Value,
    pub status: PendingToolUseStatus,
    /// Whether running this tool may edit project files.
    pub may_perform_edits: bool,
}
128
/// A tool invocation waiting for explicit user approval before running.
#[derive(Debug, Clone)]
pub struct Confirmation {
    pub tool_use_id: LanguageModelToolUseId,
    pub input: serde_json::Value,
    pub ui_text: Arc<str>,
    /// The completion request this tool use originated from.
    pub request: Arc<LanguageModelRequest>,
    /// The tool to run once the user confirms.
    pub tool: Arc<dyn Tool>,
}
137
138#[derive(Debug, Clone)]
139pub enum PendingToolUseStatus {
140 InputStillStreaming,
141 Idle,
142 NeedsConfirmation(Arc<Confirmation>),
143 Running { _task: Shared<Task<()>> },
144 Error(#[allow(unused)] Arc<str>),
145}
146
147impl PendingToolUseStatus {
148 pub fn is_idle(&self) -> bool {
149 matches!(self, PendingToolUseStatus::Idle)
150 }
151
152 pub fn is_error(&self) -> bool {
153 matches!(self, PendingToolUseStatus::Error(_))
154 }
155
156 pub fn needs_confirmation(&self) -> bool {
157 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
158 }
159}
160
/// Bookkeeping attached to each tool use: the model, thread, and user prompt
/// it originated from.
#[derive(Clone)]
pub struct ToolUseMetadata {
    pub model: Arc<dyn LanguageModel>,
    pub thread_id: ThreadId,
    pub prompt_id: PromptId,
}
167
168#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
169pub struct MessageId(pub(crate) usize);
170
171impl MessageId {
172 fn post_inc(&mut self) -> Self {
173 Self(post_inc(&mut self.0))
174 }
175
176 pub fn as_usize(&self) -> usize {
177 self.0
178 }
179}
180
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text.
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
190
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content: text, thinking, and tool-use segments.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached when this message was sent.
    pub loaded_context: LoadedContext,
    /// Creases used to re-render context pills when editing this message.
    pub creases: Vec<MessageCrease>,
    /// UI-only messages (e.g. retry notices) are shown but not persisted.
    pub ui_only: bool,
}
201
202impl Message {
203 /// Returns whether the message contains any meaningful text that should be displayed
204 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
205 pub fn should_display_content(&self) -> bool {
206 self.segments.iter().all(|segment| segment.should_display())
207 }
208
209 pub fn push(&mut self, segment: MessageSegment) {
210 match segment {
211 MessageSegment::Text(text) => self.push_text(text),
212 MessageSegment::Thinking { text, signature } => self.push_thinking(text, signature),
213 MessageSegment::ToolUse(segment) => {
214 self.segments.push(MessageSegment::ToolUse(segment))
215 }
216 }
217 }
218
219 pub fn push_thinking(&mut self, text: String, signature: Option<String>) {
220 if let Some(MessageSegment::Thinking {
221 text: segment,
222 signature: current_signature,
223 }) = self.segments.last_mut()
224 {
225 if let Some(signature) = signature {
226 *current_signature = Some(signature);
227 }
228 segment.push_str(&text);
229 } else {
230 self.segments
231 .push(MessageSegment::Thinking { text, signature });
232 }
233 }
234
235 pub fn push_text(&mut self, text: String) {
236 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
237 segment.push_str(&text);
238 } else {
239 self.segments.push(MessageSegment::Text(text));
240 }
241 }
242
243 pub fn to_string(&self) -> String {
244 let mut result = String::new();
245
246 if !self.loaded_context.text.is_empty() {
247 result.push_str(&self.loaded_context.text);
248 }
249
250 for segment in &self.segments {
251 match segment {
252 MessageSegment::Text(text) => result.push_str(text),
253 MessageSegment::Thinking { text, .. } => {
254 result.push_str("<think>\n");
255 result.push_str(text);
256 result.push_str("\n</think>");
257 }
258 MessageSegment::ToolUse(ToolUseSegment { name, input, .. }) => {
259 writeln!(&mut result, "<tool_use name=\"{}\">\n", name).ok();
260 result.push_str(
261 &serde_json::to_string_pretty(input)
262 .unwrap_or("<failed to serialize input>".into()),
263 );
264 result.push_str("\n</tool_use>");
265 }
266 }
267 }
268
269 result
270 }
271}
272
/// A single piece of content within a [`Message`].
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Model "thinking" content, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// An inline tool invocation and, eventually, its result.
    ToolUse(ToolUseSegment),
}
283
/// The tool-use portion of an assistant message: the requested input plus,
/// once available, the tool's output and status.
#[derive(Debug, Clone)]
pub struct ToolUseSegment {
    pub name: Arc<str>,
    /// JSON input supplied by the model.
    pub input: serde_json::Value,
    /// Optional rich UI card rendering this tool use.
    pub card: Option<AnyToolCard>,
    /// `None` until the tool finishes; then the result content or the error.
    pub output: Option<Result<LanguageModelToolResultContent, Arc<anyhow::Error>>>,
    pub status: ToolUseStatus,
}
292
#[cfg(test)]
impl PartialEq for ToolUseSegment {
    // Hand-written because `output` holds `Arc<anyhow::Error>`, which has no
    // `PartialEq`; errors are compared by their rendered message instead.
    // NOTE(review): `status` is deliberately(?) excluded from the comparison —
    // confirm this is intended for test equality.
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
            && self.input == other.input
            && self.card == other.card
            && self
                .output
                .as_ref()
                .map(|r| r.as_ref().map_err(|err| err.to_string()))
                == other
                    .output
                    .as_ref()
                    .map(|r| r.as_ref().map_err(|err| err.to_string()))
    }
}
309
310impl MessageSegment {
311 pub fn should_display(&self) -> bool {
312 match self {
313 Self::Text(text) => text.is_empty(),
314 Self::Thinking { text, .. } => text.is_empty(),
315 Self::ToolUse { .. } => true,
316 }
317 }
318
319 pub fn text(&self) -> Option<&str> {
320 match self {
321 MessageSegment::Text(text) => Some(text),
322 _ => None,
323 }
324 }
325}
326
/// Point-in-time capture of project state: worktrees, unsaved buffer paths,
/// and when the snapshot was taken.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
333
/// Snapshot of a single worktree: its path and git state (if it is a repo).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}
339
/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Rendered diff of uncommitted changes, when available.
    pub diff: Option<String>,
}
347
/// Pairs a message with a git checkpoint of the project, so the project can
/// be rolled back to the state associated with that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
353
/// User thumbs-up/down feedback on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
359
360pub enum LastRestoreCheckpoint {
361 Pending {
362 message_id: MessageId,
363 },
364 Error {
365 message_id: MessageId,
366 error: String,
367 },
368}
369
370impl LastRestoreCheckpoint {
371 pub fn message_id(&self) -> MessageId {
372 match self {
373 LastRestoreCheckpoint::Pending { message_id } => *message_id,
374 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
375 }
376 }
377}
378
379#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
380pub enum DetailedSummaryState {
381 #[default]
382 NotGenerated,
383 Generating {
384 message_id: MessageId,
385 },
386 Generated {
387 text: SharedString,
388 message_id: MessageId,
389 },
390}
391
392impl DetailedSummaryState {
393 fn text(&self) -> Option<SharedString> {
394 if let Self::Generated { text, .. } = self {
395 Some(text.clone())
396 } else {
397 None
398 }
399 }
400}
401
/// Aggregated token consumption for a thread, relative to the model's
/// context-window size.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Maximum tokens allowed by the configured model (0 when unknown).
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies the current usage against the context window.
    ///
    /// Returns [`TokenUsageRatio::Normal`] when `max` is unknown (0), since
    /// there is no limit to warn about.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via the
        // ZED_THREAD_WARNING_THRESHOLD environment variable. A missing or
        // malformed value falls back to the default instead of panicking
        // (the previous code called `.parse().unwrap()`).
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more consumed.
    /// Saturates instead of overflowing on pathological inputs.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total.saturating_add(tokens),
            max: self.max,
        }
    }
}

/// Severity bucket produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
446
/// Progress of a completion request through the provider-side queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is waiting at the given queue position.
    Queued { position: usize },
    /// The request has started being processed.
    Started,
}
453
/// An in-flight agent turn: the driving task plus a channel to cancel it.
struct PendingTurn {
    task: Task<Result<()>>, // todo!("get rid of error")
    cancel_tx: oneshot::Sender<()>,
}
458
/// A tool invocation scheduled during the current turn, pending completion.
struct PendingToolUse2 {
    /// Index of the corresponding segment within the assistant message.
    index_in_message: usize,
    /// The tool use as requested by the model.
    request: LanguageModelToolUse,
    /// Task resolving to the tool's output; consumed (taken) by `result`.
    output: Option<Task<Result<ToolResultOutput>>>,
}
464
465impl PendingToolUse2 {
466 async fn result(
467 &mut self,
468 ) -> (
469 LanguageModelToolResult,
470 Result<LanguageModelToolResultContent>,
471 ) {
472 let content;
473 let output;
474 let thread_result;
475 match self.output.take().unwrap().await {
476 Ok(tool_output) => {
477 content = match tool_output.content {
478 ToolResultContent::Text(text) => {
479 LanguageModelToolResultContent::Text(text.into())
480 }
481 ToolResultContent::Image(image) => LanguageModelToolResultContent::Image(image),
482 };
483 thread_result = Ok(content.clone());
484 output = tool_output.output;
485 }
486 Err(error) => {
487 content = LanguageModelToolResultContent::Text(error.to_string().into());
488 thread_result = Err(error);
489 output = None;
490 }
491 };
492
493 (
494 LanguageModelToolResult {
495 tool_use_id: self.request.id.clone(),
496 tool_name: self.request.name.clone(),
497 is_error: thread_result.is_err(),
498 content,
499 output,
500 },
501 thread_result,
502 )
503 }
504}
505
/// A thread of conversation with the LLM.
/// todo! Rename to ZedAgentThread (it will implement the AgentThread trait, along with externa)
pub struct ZedAgentThread {
    // Formerly Thread -> ZedAgent -> ZedAgentThread
    id: ThreadId,
    /// Id to assign to the next inserted message.
    next_message_id: MessageId,
    /// UI-visible messages, kept in insertion order (sorted by ascending id).
    thread_messages: Vec<Message>, // todo!() remove this and just use .messages.
    summary: ThreadSummary,
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Git checkpoints keyed by the message they were captured for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    action_log: Entity<ActionLog>,
    updated_at: DateTime<Utc>,
    /// When the last streaming chunk arrived; used for staleness detection.
    last_received_chunk_at: Option<Instant>,

    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    /// Raw request messages for the model (see todo on `thread_messages`).
    messages: Vec<LanguageModelRequestMessage>,
    // NOTE(review): presumably maps a user MessageId to an index into
    // `messages` — confirm against the code that populates it.
    thread_user_messages: HashMap<MessageId, usize>,
    /// The currently running turn, if any.
    pending_turn: Option<PendingTurn>,

    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,

    /// Background capture of the project's state at thread creation.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    /// Automatic-retry bookkeeping for the current request, if retrying.
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Optional callback observing each request and its completion events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
559
impl ZedAgentThread {
    /// The thread's stable unique id.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// True when the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.thread_messages.is_empty()
    }

    /// Current state of the thread's short summary/title.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    /// Sets the thread summary, emitting [`ThreadEvent::SummaryChanged`] only
    /// when it actually changes.
    ///
    /// No-ops while a summary is pending or being generated; an empty summary
    /// is replaced with [`ThreadSummary::DEFAULT`].
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            // Don't overwrite a summary that is still being computed.
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // After an error, compare against the default title.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// Iterates the thread's messages in insertion order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> + DoubleEndedIterator {
        self.thread_messages.iter()
    }

    /// Appends a new message with the next id, bumps the updated-at
    /// timestamp, and emits [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.thread_messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            ui_only: false,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }

    /// Appends an assistant message (with no context or creases) and records
    /// that a streaming chunk was just received.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.received_chunk();

        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends `segment` to the trailing assistant message, creating a new
    /// assistant message when the last message is not from the assistant.
    /// Emits the matching streaming event and returns the index of the
    /// segment within that message.
    pub fn push_assistant_message_segment(
        &mut self,
        segment: MessageSegment,
        cx: &mut Context<Self>,
    ) -> usize {
        self.received_chunk();

        if let Some(last_message) = self.thread_messages.last_mut() {
            if last_message.role == Role::Assistant {
                match &segment {
                    MessageSegment::Text(chunk) => {
                        cx.emit(ThreadEvent::ReceivedTextChunk);
                        cx.emit(ThreadEvent::StreamedAssistantText(
                            last_message.id,
                            chunk.clone(),
                        ));
                    }
                    MessageSegment::Thinking { text, .. } => {
                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                            last_message.id,
                            text.clone(),
                        ));
                    }
                    MessageSegment::ToolUse(_) => {
                        cx.emit(ThreadEvent::StreamedToolUse2 {
                            message_id: last_message.id,
                            segment_index: last_message.segments.len(),
                        });
                    }
                }
                // NOTE: `push` may merge into the previous segment, in which
                // case the returned index refers to the merged segment.
                last_message.push(segment);
                // todo! emit a new streamed segment event
                return last_message.segments.len() - 1;
            }
        }

        self.insert_message(
            Role::Assistant,
            vec![segment],
            LoadedContext::default(),
            Vec::new(),
            cx,
        );
        0
    }

    /// Appends a pending tool-use segment to the trailing assistant message.
    /// Returns the segment index for later [`Self::set_tool_call_result`].
    pub fn push_tool_call(
        &mut self,
        name: Arc<str>,
        input: serde_json::Value,
        card: Option<AnyToolCard>,
        cx: &mut Context<Self>,
    ) -> usize {
        self.push_assistant_message_segment(
            MessageSegment::ToolUse(ToolUseSegment {
                name,
                input,
                card,
                output: None,
                status: ToolUseStatus::Pending,
            }),
            cx,
        )
    }

    /// Records the result of the tool call at `segment_index` in the trailing
    /// assistant message, updating the segment's status and output.
    pub fn set_tool_call_result(
        &mut self,
        segment_index: usize,
        result: Result<LanguageModelToolResultContent>,
        cx: &mut Context<Self>,
    ) {
        if let Some(last_message) = self.thread_messages.last_mut() {
            if last_message.role == Role::Assistant {
                if let Some(MessageSegment::ToolUse(ToolUseSegment { output, status, .. })) =
                    last_message.segments.get_mut(segment_index)
                {
                    cx.emit(ThreadEvent::StreamedToolUse2 {
                        message_id: last_message.id,
                        segment_index,
                    });

                    *status = match &result {
                        Ok(content) => ToolUseStatus::Finished(
                            content
                                .to_str()
                                .map(|str| str.to_owned().into())
                                .unwrap_or_default(),
                        ),
                        Err(err) => ToolUseStatus::Error(err.to_string().into()),
                    };
                    *output = Some(result.map_err(Arc::new));
                    return;
                } else {
                    // NOTE(review): if `debug_panic!` returns (non-panicking
                    // builds), control falls through to the trailing
                    // `debug_panic!` below as well — confirm this double
                    // report is intended.
                    debug_panic!("invalid segment index");
                }
            }
        };
        // todo! emit segment update event

        debug_panic!("Expected last message's role assistant")
    }

    /// Inserts a system, UI-only message (never persisted) describing a retry.
    pub fn insert_retry_message(&mut self, retry_message: String, cx: &mut Context<Self>) {
        // Add a UI-only message instead of a regular message
        let id = self.next_message_id.post_inc();
        self.thread_messages.push(Message {
            id,
            role: Role::System,
            segments: vec![MessageSegment::Text(retry_message)],
            loaded_context: LoadedContext::default(),
            creases: Vec::new(),
            ui_only: true,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
    }

    /// Replaces the role, segments, and creases of message `id`, optionally
    /// replacing its context and recording a checkpoint. Returns `false`
    /// when no such message exists.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self
            .thread_messages
            .iter_mut()
            .find(|message| message.id == id)
        else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes message `id`. Returns `false` when no such message exists.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self
            .thread_messages
            .iter()
            .position(|message| message.id == id)
        else {
            return false;
        };
        self.thread_messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }

    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
    pub fn text(&self) -> String {
        let mut text = String::new();

        for message in &self.thread_messages {
            text.push_str(match message.role {
                language_model::Role::User => "User:",
                language_model::Role::Assistant => "Agent:",
                language_model::Role::System => "System:",
            });
            text.push('\n');

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(content) => text.push_str(content),
                    MessageSegment::Thinking { text: content, .. } => {
                        text.push_str(&format!("<think>{}</think>", content))
                    }
                    // Tool uses are omitted from the textual representation.
                    MessageSegment::ToolUse { .. } => {}
                }
            }
            text.push('\n');
        }

        text
    }

    /// Bumps the thread's last-updated timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Iterates the contexts attached to message `id` (empty if not found).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.thread_messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }

    /// The git checkpoint associated with message `id`, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    /// State of the most recent checkpoint restore, if pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Restores the project to `checkpoint`; on success, truncates the thread
    /// at the checkpoint's message (via `truncate`, defined outside this
    /// excerpt). Progress and failure are reflected in
    /// `last_restore_checkpoint` and surfaced via
    /// [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project.read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }

    /// The action log shared with this thread's tools.
    pub fn action_log(&self) -> Entity<ActionLog> {
        self.action_log.clone()
    }

    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Looks up a message by id.
    ///
    /// Uses binary search: ids are assigned monotonically and messages are
    /// stored in insertion order, so the list is sorted by id.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .thread_messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.thread_messages.get(index)
    }

    /// Indicates whether streaming of language model events is stale.
    /// Returns `None` when no chunk has been received yet.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before generation counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Marks that a streaming chunk was just received.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
}
923
/// Tracks automatic retry progress for a failed completion request.
#[derive(Clone, Debug)]
struct RetryState {
    /// Retry attempts made so far.
    attempt: u8,
    /// Upper bound on attempts (see [`MAX_RETRY_ATTEMPTS`]).
    max_attempts: u8,
    /// Intent of the completion being retried.
    intent: CompletionIntent,
}
930
931#[derive(Clone, Debug, PartialEq, Eq)]
932pub enum ThreadSummary {
933 Pending,
934 Generating,
935 Ready(SharedString),
936 Error,
937}
938
939impl ThreadSummary {
940 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
941
942 pub fn or_default(&self) -> SharedString {
943 self.unwrap_or(Self::DEFAULT)
944 }
945
946 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
947 self.ready().unwrap_or_else(|| message.into())
948 }
949
950 pub fn ready(&self) -> Option<SharedString> {
951 match self {
952 ThreadSummary::Ready(summary) => Some(summary.clone()),
953 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
954 }
955 }
956}
957
/// Records that the last request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
965
/// Everything needed to insert a user message into a thread.
pub struct UserMessageParams {
    pub text: String,
    /// Creases for context pills embedded in the text.
    pub creases: Vec<MessageCrease>,
    /// Optional git checkpoint to associate with this message.
    pub checkpoint: Option<GitStoreCheckpoint>,
    /// Context loaded for this message.
    pub context: ContextLoadResult,
}
972
973impl<T: Into<String>> From<T> for UserMessageParams {
974 fn from(text: T) -> Self {
975 UserMessageParams {
976 text: text.into(),
977 creases: Vec::new(),
978 checkpoint: None,
979 context: ContextLoadResult::default(),
980 }
981 }
982}
983
/// One user → assistant exchange: the initiating user message and the stream
/// of response events it produced.
pub struct Turn {
    user_message_id: MessageId,
    response_events: LocalBoxStream<'static, Result<ResponseEvent>>,
}
988
/// Handle to a started tool call: the driving task and an optional UI card.
struct ToolCallResult {
    task: Task<Result<()>>,
    card: Option<AnyToolCard>,
}
993
/// Events streamed while the model produces a response.
pub enum ResponseEvent {
    /// A chunk of plain response text.
    Text(String),
    /// A chunk of "thinking" text.
    Thinking(String),
    /// Partial tool-call input as it streams in.
    ToolCallChunk {
        id: LanguageModelToolUseId,
        label: String,
        input: serde_json::Value,
    },
    /// A complete tool call; invoking `run` starts it.
    ToolCall {
        id: LanguageModelToolUseId,
        /// Whether the user must approve before the tool runs.
        needs_confirmation: bool,
        label: String,
        run: Box<dyn FnOnce(Option<AnyWindowHandle>, &mut App) -> ToolCallResult>,
    },
    // NOTE(review): presumably a tool use whose input failed to parse or
    // validate — confirm with the code that emits this variant.
    InvalidToolCallChunk(LanguageModelToolUse),
}
1010
1011impl ZedAgentThread {
    /// Creates an empty thread for `project`, wired to the globally
    /// configured default model and agent profile from settings.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            summary: ThreadSummary::Pending,
            pending_checkpoint: None,
            next_message_id: MessageId(0),
            id: ThreadId::new(),
            thread_messages: Vec::new(),
            checkpoints_by_message: HashMap::default(),
            last_restore_checkpoint: None,
            updated_at: Utc::now(),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            last_received_chunk_at: None,
            pending_turn: None,
            messages: Vec::new(),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_uses_by_assistant_message: HashMap::default(),
            tool_results: HashMap::default(),
            pending_tool_uses_by_id: HashMap::default(),
            tool_result_cards: HashMap::default(),
            tool_use_metadata_by_id: HashMap::default(),
            thread_user_messages: Default::default(),
            // Snapshot the project in the background (see `project_snapshot`,
            // defined outside this excerpt).
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
1074
1075 pub fn is_turn_end(&self, ix: usize) -> bool {
1076 if self.thread_messages.is_empty() {
1077 return false;
1078 }
1079
1080 if !self.is_generating() && ix == self.thread_messages.len() - 1 {
1081 return true;
1082 }
1083
1084 let Some(message) = self.thread_messages.get(ix) else {
1085 return false;
1086 };
1087
1088 if message.role != Role::Assistant {
1089 return false;
1090 }
1091
1092 self.thread_messages
1093 .get(ix + 1)
1094 .and_then(|message| {
1095 self.message(message.id)
1096 .map(|next_message| next_message.role == Role::User)
1097 })
1098 .unwrap_or(false)
1099 }
1100
/// Reconstructs a `Thread` from its serialized form.
///
/// `window` is `None` in headless mode. State that is not persisted
/// (pending checkpoints/turns/completions, feedback, callbacks) is reset
/// to defaults; `RedactedThinking` segments are currently dropped
/// (see the `todo!` below).
pub fn deserialize(
    id: ThreadId,
    serialized: SerializedThread,
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    project_context: SharedProjectContext,
    window: Option<&mut Window>, // None in headless mode
    cx: &mut Context<Self>,
) -> Self {
    // Continue numbering after the highest persisted message id.
    let next_message_id = MessageId(
        serialized
            .messages
            .last()
            .map(|message| message.id.0 + 1)
            .unwrap_or(0),
    );
    let (detailed_summary_tx, detailed_summary_rx) =
        postage::watch::channel_with(serialized.detailed_summary_state);

    // Re-select the persisted model if the registry still knows it,
    // otherwise fall back to the registry's default model.
    let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
        serialized
            .model
            .and_then(|model| {
                let model = SelectedModel {
                    provider: model.provider.clone().into(),
                    model: model.model.clone().into(),
                };
                registry.select_model(&model, cx)
            })
            .or_else(|| registry.default_model())
    });

    // Missing persisted settings fall back to the current global settings.
    let completion_mode = serialized
        .completion_mode
        .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
    let profile_id = serialized
        .profile
        .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

    // Rebuild tool-use bookkeeping (metadata, per-message uses, result
    // cards, results) from the serialized messages.
    let DeserializedToolUse {
        tool_use_metadata_by_id,
        tool_uses_by_assistant_message,
        tool_result_cards,
        tool_results,
    } = DeserializedToolUse::new(&serialized.messages, &project, &tools, window, cx);

    let messages = serialized
        .messages
        .into_iter()
        .map(|message| Message {
            id: message.id,
            role: message.role,
            segments: message
                .segments
                .into_iter()
                .filter_map(|segment| match segment {
                    SerializedMessageSegment::Text { text } => Some(MessageSegment::Text(text)),
                    SerializedMessageSegment::Thinking { text, signature } => {
                        Some(MessageSegment::Thinking { text, signature })
                    }
                    SerializedMessageSegment::RedactedThinking { .. } => {
                        // todo! migrate
                        None
                    }
                })
                .collect(),
            // Only the flattened context text is persisted; the structured
            // contexts and images are not restored.
            loaded_context: LoadedContext {
                contexts: Vec::new(),
                text: message.context,
                images: Vec::new(),
            },
            creases: message
                .creases
                .into_iter()
                .map(|crease| MessageCrease {
                    range: crease.start..crease.end,
                    icon_path: crease.icon_path,
                    label: crease.label,
                    context: None,
                })
                .collect(),
            ui_only: false, // UI-only messages are not persisted
        })
        .collect();

    Self {
        id,
        next_message_id,
        thread_messages: messages,
        pending_checkpoint: None,
        checkpoints_by_message: HashMap::default(),
        last_restore_checkpoint: None,
        updated_at: serialized.updated_at,
        action_log: cx.new(|_| ActionLog::new(project.clone())),
        summary: ThreadSummary::Ready(serialized.summary),
        last_received_chunk_at: None,
        pending_turn: None,
        messages: Vec::new(),
        pending_summary: Task::ready(None),
        detailed_summary_task: Task::ready(None),
        detailed_summary_tx,
        detailed_summary_rx,
        completion_mode,
        retry_state: None,
        last_prompt_id: PromptId::new(),
        project_context,
        completion_count: 0,
        pending_completions: Vec::new(),
        project: project.clone(),
        prompt_builder,
        tools: tools.clone(),
        initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
        request_token_usage: serialized.request_token_usage,
        thread_user_messages: Default::default(),
        cumulative_token_usage: serialized.cumulative_token_usage,
        exceeded_window_error: None,
        tool_use_limit_reached: serialized.tool_use_limit_reached,
        feedback: None,
        message_feedback: HashMap::default(),
        last_auto_capture_at: None,
        request_callback: None,
        remaining_turns: u32::MAX,
        configured_model,
        profile: AgentProfile::new(profile_id, tools),

        pending_tool_uses_by_id: Default::default(),
        tool_use_metadata_by_id,
        tool_uses_by_assistant_message,
        tool_result_cards,
        tool_results,
    }
}
1234
1235 pub fn set_request_callback(
1236 &mut self,
1237 callback: impl 'static
1238 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
1239 ) {
1240 self.request_callback = Some(Box::new(callback));
1241 }
1242
/// The agent profile currently in use by this thread.
pub fn profile(&self) -> &AgentProfile {
    &self.profile
}
1246
1247 pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
1248 if &id != self.profile.id() {
1249 self.profile = AgentProfile::new(id, self.tools.clone());
1250 cx.emit(ThreadEvent::ProfileChanged);
1251 }
1252 }
1253
/// Replaces `last_prompt_id` with a freshly generated one.
pub fn advance_prompt_id(&mut self) {
    // todo! remove fn
    self.last_prompt_id = PromptId::new();
}
1258
/// A clone of the shared project context handle.
pub fn project_context(&self) -> SharedProjectContext {
    self.project_context.clone()
}
1262
1263 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
1264 if self.configured_model.is_none() {
1265 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
1266 }
1267 self.configured_model.clone()
1268 }
1269
/// The currently configured model, if any.
pub fn configured_model(&self) -> Option<ConfiguredModel> {
    self.configured_model.clone()
}
1273
/// Sets (or clears, with `None`) the configured model and notifies observers.
pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
    self.configured_model = model;
    cx.notify();
}
1278
/// The thread's current completion mode.
pub fn completion_mode(&self) -> CompletionMode {
    self.completion_mode
}
1282
/// Sets the thread's completion mode.
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
    self.completion_mode = mode;
}
1286
1287 pub fn is_generating(&self) -> bool {
1288 self.pending_turn.is_some()
1289 || !self.pending_completions.is_empty()
1290 || !self.all_tools_finished()
1291 }
1292
1293 pub fn queue_state(&self) -> Option<QueueState> {
1294 self.pending_completions
1295 .first()
1296 .map(|pending_completion| pending_completion.queue_state)
1297 }
1298
/// The thread's tool working set.
pub fn tools(&self) -> &Entity<ToolWorkingSet> {
    &self.tools
}
1302
1303 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
1304 self.pending_tool_uses_by_id
1305 .values()
1306 .find(|tool_use| &tool_use.id == id)
1307 }
1308
/// Iterates over pending tool uses whose status requires user confirmation.
pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
    self.pending_tool_uses_by_id
        .values()
        .filter(|tool_use| tool_use.status.needs_confirmation())
}
1314
/// Whether any tool uses are still tracked as pending (including errored ones).
pub fn has_pending_tool_uses(&self) -> bool {
    !self.pending_tool_uses_by_id.is_empty()
}
1318
1319 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
1320 let pending_checkpoint = if self.is_generating() {
1321 return;
1322 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
1323 checkpoint
1324 } else {
1325 return;
1326 };
1327
1328 self.finalize_checkpoint(pending_checkpoint, cx);
1329 }
1330
/// Compares `pending_checkpoint` against a fresh checkpoint of the current
/// git state and records it only if the two differ (i.e. something actually
/// changed since the checkpoint was taken). If taking or comparing the fresh
/// checkpoint fails, the pending checkpoint is recorded anyway.
fn finalize_checkpoint(
    &mut self,
    pending_checkpoint: ThreadCheckpoint,
    cx: &mut Context<Self>,
) {
    let git_store = self.project.read(cx).git_store().clone();
    let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
    cx.spawn(async move |this, cx| match final_checkpoint.await {
        Ok(final_checkpoint) => {
            // A failed comparison is treated as "not equal" so the
            // checkpoint is conservatively kept.
            let equal = git_store
                .update(cx, |store, cx| {
                    store.compare_checkpoints(
                        pending_checkpoint.git_checkpoint.clone(),
                        final_checkpoint.clone(),
                        cx,
                    )
                })?
                .await
                .unwrap_or(false);

            // Only keep the checkpoint when it captured real changes.
            if !equal {
                this.update(cx, |this, cx| {
                    this.insert_checkpoint(pending_checkpoint, cx)
                })?;
            }

            Ok(())
        }
        // Could not take a fresh checkpoint; keep the pending one to be safe.
        Err(_) => this.update(cx, |this, cx| {
            this.insert_checkpoint(pending_checkpoint, cx)
        }),
    })
    .detach();
}
1365
/// Whether the tool-use limit has been reached for this thread.
pub fn tool_use_limit_reached(&self) -> bool {
    self.tool_use_limit_reached
}
1369
1370 /// Returns whether all of the tool uses have finished running.
1371 pub fn all_tools_finished(&self) -> bool {
1372 // If the only pending tool uses left are the ones with errors, then
1373 // that means that we've finished running all of the pending tools.
1374 self.pending_tool_uses_by_id
1375 .values()
1376 .all(|pending_tool_use| pending_tool_use.status.is_error())
1377 }
1378
1379 /// Returns whether any pending tool uses may perform edits
1380 pub fn has_pending_edit_tool_uses(&self) -> bool {
1381 self.pending_tool_uses_by_id
1382 .values()
1383 .filter(|pending_tool_use| !pending_tool_use.status.is_error())
1384 .any(|pending_tool_use| pending_tool_use.may_perform_edits)
1385 }
1386
1387 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
1388 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
1389 return Vec::new();
1390 };
1391
1392 let mut tool_uses = Vec::new();
1393
1394 for tool_use in tool_uses_for_message.iter() {
1395 let tool_result = self.tool_results.get(&tool_use.id);
1396
1397 let status = (|| {
1398 if let Some(tool_result) = tool_result {
1399 let content = tool_result
1400 .content
1401 .to_str()
1402 .map(|str| str.to_owned().into())
1403 .unwrap_or_default();
1404
1405 return if tool_result.is_error {
1406 ToolUseStatus::Error(content)
1407 } else {
1408 ToolUseStatus::Finished(content)
1409 };
1410 }
1411
1412 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
1413 match pending_tool_use.status {
1414 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
1415 PendingToolUseStatus::NeedsConfirmation { .. } => {
1416 ToolUseStatus::NeedsConfirmation
1417 }
1418 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
1419 PendingToolUseStatus::Error(ref err) => {
1420 ToolUseStatus::Error(err.clone().into())
1421 }
1422 PendingToolUseStatus::InputStillStreaming => {
1423 ToolUseStatus::InputStillStreaming
1424 }
1425 }
1426 } else {
1427 ToolUseStatus::Pending
1428 }
1429 })();
1430
1431 let (icon, needs_confirmation) =
1432 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1433 (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
1434 } else {
1435 (IconName::Cog, false)
1436 };
1437
1438 let tool_ui_label = if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1439 if tool_use.is_input_complete {
1440 tool.ui_text(&tool_use.input).into()
1441 } else {
1442 tool.still_streaming_ui_text(&tool_use.input).into()
1443 }
1444 } else {
1445 format!("Unknown tool {:?}", tool_use.name).into()
1446 };
1447
1448 tool_uses.push(ToolUse {
1449 id: tool_use.id.clone(),
1450 name: tool_use.name.clone().into(),
1451 ui_text: tool_ui_label,
1452 input: tool_use.input.clone(),
1453 status,
1454 icon,
1455 needs_confirmation,
1456 })
1457 }
1458
1459 tool_uses
1460 }
1461
1462 pub fn tool_results_for_message(
1463 &self,
1464 assistant_message_id: MessageId,
1465 ) -> Vec<&LanguageModelToolResult> {
1466 let Some(tool_uses) = self
1467 .tool_uses_by_assistant_message
1468 .get(&assistant_message_id)
1469 else {
1470 return Vec::new();
1471 };
1472
1473 tool_uses
1474 .iter()
1475 .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
1476 .collect()
1477 }
1478
/// The recorded result for the given tool use id, if any.
pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
    self.tool_results.get(id)
}
1482
1483 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
1484 match &self.tool_result(id)?.content {
1485 LanguageModelToolResultContent::Text(text) => Some(text),
1486 LanguageModelToolResultContent::Image(_) => {
1487 // TODO: We should display image
1488 None
1489 }
1490 }
1491 }
1492
/// The UI card associated with the given tool use, if one was recorded.
pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
    self.tool_result_cards.get(id).cloned()
}
1496
1497 /// Return tools that are both enabled and supported by the model
1498 pub fn available_tools(
1499 &self,
1500 cx: &App,
1501 model: Arc<dyn LanguageModel>,
1502 ) -> Vec<LanguageModelRequestTool> {
1503 if model.supports_tools() {
1504 resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
1505 .into_iter()
1506 .filter_map(|(name, tool)| {
1507 // Skip tools that cannot be supported
1508 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
1509 Some(LanguageModelRequestTool {
1510 name,
1511 description: tool.description(),
1512 input_schema,
1513 })
1514 })
1515 .collect()
1516 } else {
1517 Vec::default()
1518 }
1519 }
1520
/// Serializes this thread into a format for storage or telemetry.
///
/// Awaits the initial project snapshot before reading the thread state.
/// UI-only messages are skipped (they are never persisted), and `ToolUse`
/// segments are currently dropped (see the `todo!` below).
pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
    let initial_project_snapshot = self.initial_project_snapshot.clone();
    cx.spawn(async move |this, cx| {
        let initial_project_snapshot = initial_project_snapshot.await;
        this.read_with(cx, |this, cx| {
            SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at,
                messages: this
                    .messages()
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .filter_map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    Some(SerializedMessageSegment::Text { text: text.clone() })
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    Some(SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    })
                                }
                                MessageSegment::ToolUse { .. } => {
                                    // todo!("change serialization to use the agent's LanguageModelRequestMessages")
                                    None
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            }
        })
    })
}
1607
/// How many turns this thread may still take (`u32::MAX` means unlimited).
pub fn remaining_turns(&self) -> u32 {
    self.remaining_turns
}
1611
/// Sets how many turns this thread may still take.
pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
    self.remaining_turns = remaining_turns;
}
1615
1616 pub fn cancel(&mut self) -> Option<Task<Result<()>>> {
1617 self.pending_turn.take().map(|turn| {
1618 turn.cancel_tx.send(()).ok();
1619 turn.task
1620 })
1621 }
1622
1623 pub fn truncate(&mut self, old_message_id: MessageId, cx: &mut Context<Self>) {
1624 let Some(message_ix) = self
1625 .thread_messages
1626 .iter()
1627 .rposition(|message| message.id == old_message_id)
1628 else {
1629 return;
1630 };
1631 for deleted_message in self.thread_messages.drain(message_ix..) {
1632 self.checkpoints_by_message.remove(&deleted_message.id);
1633 }
1634 cx.notify();
1635
1636 if let Some(old_message_ix) = self.thread_user_messages.remove(&old_message_id) {
1637 self.messages.truncate(old_message_ix);
1638 }
1639 }
1640
/// Sends a user message and spawns a new turn driving `model`.
///
/// Returns a future resolving to the [`Turn`] handle (the inserted user
/// message's id plus the stream of response events); it errors if the
/// spawned task drops the sender before handing the turn back.
pub fn send_message2(
    &mut self,
    user_message: impl Into<UserMessageParams>,
    model: Arc<dyn LanguageModel>,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Self>,
) -> LocalBoxFuture<'static, Result<Turn>> {
    self.advance_prompt_id();

    let user_message = user_message.into();
    // Cancel any in-flight turn; the spawned task awaits its shutdown first
    // so turns never overlap.
    let prev_turn = self.cancel();
    let (cancel_tx, cancel_rx) = oneshot::channel();
    let (turn_tx, turn_rx) = oneshot::channel();
    self.pending_turn = Some(PendingTurn {
        task: cx.spawn(async move |this, cx| {
            if let Some(prev_turn) = prev_turn {
                prev_turn.await?;
            }

            let user_message_id =
                this.update(cx, |this, cx| this.insert_user_message(user_message, cx))?;
            let (response_events_tx, response_events_rx) = mpsc::unbounded();
            // Hand the turn handle back to the caller before the loop runs.
            turn_tx
                .send(Turn {
                    user_message_id,
                    response_events: response_events_rx.boxed_local(),
                })
                .ok();

            Self::turn_loop2(
                &this,
                model,
                CompletionIntent::UserPrompt,
                cancel_rx,
                response_events_tx,
                window,
                cx,
            )
            .await?;

            // Clear the pending turn so `is_generating` reports false again.
            this.update(cx, |this, _cx| this.pending_turn.take()).ok();

            Ok(())
        }),
        cancel_tx,
    });

    async move { turn_rx.await.map_err(|_| anyhow!("Turn loop failed")) }.boxed_local()
}
1690
/// Work-in-progress replacement for `turn_loop` that reports progress via
/// `response_events_tx` instead of mutating thread messages directly.
///
/// Streams completion events from `model`, accumulating the assistant
/// response as ordered chunks, retrying transient provider errors with
/// exponential backoff (or a server-provided delay), and racing the whole
/// send against `cancel_rx`.
///
/// NOTE(review): this function is mid-refactor — the post-stream drain loop
/// and the cancel/finish handling still reference names from the old
/// `turn_loop` (`pending_tool_use`, `pending_tool_uses`,
/// `tool_results_message`, and `assistant_message.push`/`.content`) that do
/// not exist in this scope, and several branches are `todo!`. Confirm
/// intended behavior against `turn_loop` before finishing the port.
async fn turn_loop2(
    this: &WeakEntity<Self>,
    model: Arc<dyn LanguageModel>,
    mut intent: CompletionIntent,
    mut cancel_rx: oneshot::Receiver<()>,
    mut response_events_tx: mpsc::UnboundedSender<Result<ResponseEvent>>,
    window: Option<AnyWindowHandle>,
    cx: &mut AsyncApp,
) -> Result<()> {
    // Retry bookkeeping: `custom_delay` (e.g. a provider `retry_after`)
    // overrides the computed exponential backoff.
    struct RetryState {
        attempts: u8,
        custom_delay: Option<Duration>,
    }
    let mut retry_state: Option<RetryState> = None;

    // Accumulates the streamed assistant response as ordered chunks.
    struct PendingAssistantMessage {
        chunks: VecDeque<PendingAssistantMessageChunk>,
    }

    impl PendingAssistantMessage {
        // Appends text, coalescing with a trailing text chunk if present.
        fn push_text(&mut self, text: String) {
            if let Some(PendingAssistantMessageChunk::Text(existing_text)) =
                self.chunks.back_mut()
            {
                existing_text.push_str(&text);
            } else {
                self.chunks
                    .push_back(PendingAssistantMessageChunk::Text(text));
            }
        }

        // Appends thinking text, coalescing with a trailing thinking chunk;
        // an existing signature is kept, otherwise the new one is used.
        fn push_thinking(&mut self, text: String, signature: Option<String>) {
            if let Some(PendingAssistantMessageChunk::Thinking {
                text: existing_text,
                signature: existing_signature,
            }) = self.chunks.back_mut()
            {
                *existing_signature = existing_signature.take().or(signature);
                existing_text.push_str(&text);
            } else {
                self.chunks
                    .push_back(PendingAssistantMessageChunk::Thinking { text, signature });
            }
        }
    }

    enum PendingAssistantMessageChunk {
        Text(String),
        Thinking {
            text: String,
            signature: Option<String>,
        },
        RedactedThinking {
            data: String,
        },
        ToolCall(PendingAssistantToolCall),
    }

    // A requested tool call plus the channel its output will arrive on.
    struct PendingAssistantToolCall {
        request: LanguageModelToolUse,
        output: oneshot::Receiver<Result<ToolResultOutput>>,
    }

    loop {
        let mut segments = Vec::new();
        let mut assistant_message = PendingAssistantMessage {
            chunks: VecDeque::new(),
        };

        let send = async {
            // Back off before retrying a failed attempt.
            if let Some(retry_state) = retry_state.as_ref() {
                let delay = retry_state.custom_delay.unwrap_or_else(|| {
                    BASE_RETRY_DELAY * 2_u32.pow((retry_state.attempts - 1) as u32)
                });
                cx.background_executor().timer(delay).await;
            }

            let request = this.update(cx, |this, cx| this.build_request(&model, intent, cx))?;
            let mut events = model.stream_completion(request.clone(), cx).await?;

            while let Some(event) = events.next().await {
                let event = event?;
                match event {
                    LanguageModelCompletionEvent::StartMessage { .. } => {
                        // no-op, todo!("do we wanna insert a new message here?")
                    }
                    LanguageModelCompletionEvent::Text(chunk) => {
                        // NOTE(review): `unbounded_send` results are ignored
                        // throughout this loop — confirm that's intended.
                        response_events_tx
                            .unbounded_send(Ok(ResponseEvent::Text(chunk.clone())));
                        assistant_message.push_text(chunk);
                    }
                    LanguageModelCompletionEvent::Thinking { text, signature } => {
                        response_events_tx
                            .unbounded_send(Ok(ResponseEvent::Thinking(text.clone())));
                        assistant_message.push_thinking(text, signature);
                    }
                    LanguageModelCompletionEvent::RedactedThinking { data } => {
                        assistant_message
                            .chunks
                            .push_back(PendingAssistantMessageChunk::RedactedThinking { data });
                    }
                    LanguageModelCompletionEvent::ToolUse(tool_use) => {
                        match this
                            .read_with(cx, |this, cx| this.tool_for_name(&tool_use.name, cx))?
                        {
                            Ok(tool) => {
                                if tool_use.is_input_complete {
                                    // Emit a runnable tool call; the returned
                                    // closure executes the tool and forwards
                                    // its output through `output_tx`.
                                    let (output_tx, output_rx) = oneshot::channel();
                                    let mut request = request.clone();
                                    // todo!("add the pending assistant message (excluding the tool calls)")
                                    response_events_tx.unbounded_send(Ok(
                                        ResponseEvent::ToolCall {
                                            id: tool_use.id,
                                            needs_confirmation: cx.update(|cx| {
                                                tool.needs_confirmation(&tool_use.input, cx)
                                            })?,
                                            label: tool.ui_text(&tool_use.input),
                                            run: Box::new({
                                                let project = this
                                                    .read_with(cx, |this, _| {
                                                        this.project.clone()
                                                    })?;
                                                let action_log = this
                                                    .read_with(cx, |this, _| {
                                                        this.action_log.clone()
                                                    })?;
                                                move |window, cx| {
                                                    let assistant_tool::ToolResult {
                                                        output,
                                                        card,
                                                    } = tool.run(
                                                        tool_use.input,
                                                        Arc::new(request),
                                                        project,
                                                        action_log,
                                                        model,
                                                        window,
                                                        cx,
                                                    );

                                                    ToolCallResult {
                                                        task: cx.foreground_executor().spawn(
                                                            async move {
                                                                match output.await {
                                                                    Ok(output) => {
                                                                        output_tx
                                                                            .send(Ok(output))
                                                                            .ok();
                                                                        Ok(())
                                                                    }
                                                                    Err(error) => {
                                                                        let error =
                                                                            Arc::new(error);
                                                                        output_tx
                                                                            .send(Err(anyhow!(
                                                                                error.clone()
                                                                            )))
                                                                            .ok();
                                                                        Err(anyhow!(error))
                                                                    }
                                                                }
                                                            },
                                                        ),
                                                        card,
                                                    }
                                                }
                                            }),
                                        },
                                    ));
                                    assistant_message.chunks.push_back(
                                        PendingAssistantMessageChunk::ToolCall(
                                            PendingAssistantToolCall {
                                                request: tool_use,
                                                output: output_rx,
                                            },
                                        ),
                                    );
                                } else {
                                    // Input still streaming: surface progress only.
                                    response_events_tx.unbounded_send(Ok(
                                        ResponseEvent::ToolCallChunk {
                                            id: tool_use.id,
                                            label: tool
                                                .still_streaming_ui_text(&tool_use.input),
                                            input: tool_use.input,
                                        },
                                    ));
                                }
                            }
                            Err(error) => {
                                // Unknown tool: report it, and record the
                                // error as the (pre-resolved) tool output.
                                response_events_tx.unbounded_send(Ok(
                                    ResponseEvent::InvalidToolCallChunk(tool_use.clone()),
                                ));
                                if tool_use.is_input_complete {
                                    let (output_tx, output_rx) = oneshot::channel();
                                    output_tx.send(Err(error)).unwrap();
                                    assistant_message.chunks.push_back(
                                        PendingAssistantMessageChunk::ToolCall(
                                            PendingAssistantToolCall {
                                                request: tool_use,
                                                output: output_rx,
                                            },
                                        ),
                                    );
                                }
                            }
                        }
                    }
                    LanguageModelCompletionEvent::UsageUpdate(_token_usage) => {
                        // todo!
                    }
                    LanguageModelCompletionEvent::StatusUpdate(_completion_request_status) => {
                        // todo!
                    }
                    LanguageModelCompletionEvent::Stop(StopReason::EndTurn) => {
                        // todo!
                    }
                    LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) => {
                        // todo!
                    }
                    LanguageModelCompletionEvent::Stop(StopReason::Refusal) => {
                        // todo!
                    }
                    LanguageModelCompletionEvent::Stop(StopReason::ToolUse) => {}
                }
            }

            // NOTE(review): mid-refactor — `pending_tool_use` and
            // `tool_results_message` below are unresolved in this scope.
            while let Some(chunk) = assistant_message.chunks.pop_front() {
                match chunk {
                    PendingAssistantMessageChunk::Text(_) => todo!(),
                    PendingAssistantMessageChunk::Thinking { text, signature } => todo!(),
                    PendingAssistantMessageChunk::RedactedThinking { data } => todo!(),
                    PendingAssistantMessageChunk::ToolCall(pending_assistant_tool_call) => {
                        pending_assistant_tool_call.output.await;
                    }
                }

                let (tool_result, thread_result) = pending_tool_use.result().await;
                this.update(cx, |thread, cx| {
                    thread.set_tool_call_result(
                        pending_tool_use.index_in_message,
                        thread_result,
                        cx,
                    )
                })?;
                assistant_message.push(MessageContent::ToolUse(pending_tool_use.request));
                tool_results_message
                    .content
                    .push(MessageContent::ToolResult(tool_result));
            }

            anyhow::Ok(())
        }
        .boxed_local();

        enum SendStatus {
            Canceled,
            Finished(Result<()>),
        }

        // Race the send against cancellation.
        let status = match futures::future::select(&mut cancel_rx, send).await {
            Either::Left(_) => SendStatus::Canceled,
            Either::Right((result, _)) => SendStatus::Finished(result),
        };

        match status {
            SendStatus::Canceled => {
                // NOTE(review): `pending_tool_uses`, `tool_results_message`,
                // and `assistant_message.push`/`.content` here are leftovers
                // from `turn_loop` and are unresolved in this scope.
                for pending_tool_use in pending_tool_uses {
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: pending_tool_use.request.id.clone(),
                            tool_name: pending_tool_use.request.name.clone(),
                            is_error: true,
                            content: LanguageModelToolResultContent::Text(
                                "<User cancelled tool use>".into(),
                            ),
                            output: None,
                        }));
                    assistant_message.push(MessageContent::ToolUse(pending_tool_use.request));
                }

                this.update(cx, |this, _cx| {
                    if !assistant_message.content.is_empty() {
                        this.messages.push(assistant_message);
                    }

                    if !tool_results_message.content.is_empty() {
                        this.messages.push(tool_results_message);
                    }
                })?;

                break;
            }
            SendStatus::Finished(result) => {
                // Drain any tool calls that were started during the send.
                for mut pending_tool_use in pending_tool_uses {
                    let (tool_result, thread_result) = pending_tool_use.result().await;
                    this.update(cx, |thread, cx| {
                        thread.set_tool_call_result(
                            pending_tool_use.index_in_message,
                            thread_result,
                            cx,
                        )
                    })?;
                    assistant_message.push(MessageContent::ToolUse(pending_tool_use.request));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(tool_result));
                }

                match result {
                    Ok(_) => {
                        retry_state = None;
                    }
                    Err(error) => {
                        // Returns whether another attempt is allowed, while
                        // recording/advancing the retry state.
                        let mut retry = |custom_delay: Option<Duration>| -> bool {
                            let retry_state = retry_state.get_or_insert_with(|| RetryState {
                                attempts: 0,
                                custom_delay,
                            });
                            retry_state.attempts += 1;
                            retry_state.attempts <= MAX_RETRY_ATTEMPTS
                        };

                        if error.is::<PaymentRequiredError>() {
                            // todo!
                            // cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                        } else if let Some(_error) =
                            error.downcast_ref::<ModelRequestLimitReachedError>()
                        {
                            // todo!
                            // cx.emit(ThreadEvent::ShowError(
                            //     ThreadError::ModelRequestLimitReached { plan: error.plan },
                            // ));
                        } else if let Some(completion_error) =
                            error.downcast_ref::<LanguageModelCompletionError>()
                        {
                            match completion_error {
                                LanguageModelCompletionError::RateLimitExceeded {
                                    retry_after,
                                } => {
                                    if !retry(Some(*retry_after)) {
                                        break;
                                    }
                                }
                                LanguageModelCompletionError::Overloaded => {
                                    if !retry(None) {
                                        break;
                                    }
                                }
                                LanguageModelCompletionError::ApiInternalServerError => {
                                    if !retry(None) {
                                        break;
                                    }
                                    // todo!
                                }
                                _ => {
                                    // todo!(emit_generic_error(error, cx);)
                                    break;
                                }
                            }
                        } else if let Some(known_error) =
                            error.downcast_ref::<LanguageModelKnownError>()
                        {
                            match known_error {
                                LanguageModelKnownError::ContextWindowLimitExceeded {
                                    tokens: _,
                                } => {
                                    // todo!
                                    // this.exceeded_window_error =
                                    //     Some(ExceededWindowError {
                                    //         model_id: model.id(),
                                    //         token_count: *tokens,
                                    //     });
                                    // cx.notify();
                                    break;
                                }
                                LanguageModelKnownError::RateLimitExceeded { retry_after } => {
                                    // let provider_name = model.provider_name();
                                    // let error_message = format!(
                                    //     "{}'s API rate limit exceeded",
                                    //     provider_name.0.as_ref()
                                    // );
                                    if !retry(Some(*retry_after)) {
                                        // todo! show err
                                        break;
                                    }
                                }
                                LanguageModelKnownError::Overloaded => {
                                    //todo!
                                    // let provider_name = model.provider_name();
                                    // let error_message = format!(
                                    //     "{}'s API servers are overloaded right now",
                                    //     provider_name.0.as_ref()
                                    // );

                                    if !retry(None) {
                                        // todo! show err
                                        break;
                                    }
                                }
                                LanguageModelKnownError::ApiInternalServerError => {
                                    // let provider_name = model.provider_name();
                                    // let error_message = format!(
                                    //     "{}'s API server reported an internal server error",
                                    //     provider_name.0.as_ref()
                                    // );

                                    if !retry(None) {
                                        break;
                                    }
                                }
                                LanguageModelKnownError::ReadResponseError(_)
                                | LanguageModelKnownError::DeserializeResponse(_)
                                | LanguageModelKnownError::UnknownResponseFormat(_) => {
                                    // In the future we will attempt to re-roll response, but only once
                                    // todo!(emit_generic_error(error, cx);)
                                    break;
                                }
                            }
                        } else {
                            // todo!(emit_generic_error(error, cx));
                            break;
                        }
                    }
                }

                // Commit the accumulated messages; `done` means no tool
                // results require another model round-trip.
                let done = this.update(cx, |this, cx| {
                    let done = if assistant_message.content.is_empty() {
                        true
                    } else {
                        this.messages.push(assistant_message);
                        if tool_results_message.content.is_empty() {
                            true
                        } else {
                            this.messages.push(tool_results_message);
                            false
                        }
                    };

                    let summary_pending = matches!(this.summary(), ThreadSummary::Pending);

                    if summary_pending && (done || this.messages.len() > 6) {
                        this.summarize(cx);
                    }

                    done
                })?;

                if done && retry_state.is_none() {
                    break;
                } else {
                    // Another round-trip: feed tool results back to the model.
                    intent = CompletionIntent::ToolResults;
                }
            }
        }
    }

    Ok(())
}
2150
2151 pub fn send_message(
2152 &mut self,
2153 params: impl Into<UserMessageParams>,
2154 model: Arc<dyn LanguageModel>,
2155 window: Option<AnyWindowHandle>,
2156 cx: &mut Context<Self>,
2157 ) -> MessageId {
2158 let message_id = self.insert_user_message(params.into(), cx);
2159 self.run_turn(model, window, cx);
2160 message_id
2161 }
2162
2163 pub fn send_continue_message(
2164 &mut self,
2165 model: Arc<dyn LanguageModel>,
2166 window: Option<AnyWindowHandle>,
2167 cx: &mut Context<Self>,
2168 ) {
2169 self.insert_request_user_message(&"Continue where you left off".into());
2170 self.run_turn(model, window, cx);
2171 }
2172
/// Starts a new turn: cancels any in-flight turn, then spawns a task that
/// drives `turn_loop` with a `UserPrompt` intent.
fn run_turn(
    &mut self,
    model: Arc<dyn LanguageModel>,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Self>,
) {
    self.advance_prompt_id();

    // Cancel the previous turn (if any); the spawned task awaits its
    // shutdown before starting the new loop so turns never overlap.
    let prev_turn = self.cancel();
    let (cancel_tx, cancel_rx) = oneshot::channel();
    self.pending_turn = Some(PendingTurn {
        task: cx.spawn(async move |this, cx| {
            if let Some(prev_turn) = prev_turn {
                prev_turn.await?;
            }

            Self::turn_loop(
                &this,
                model,
                CompletionIntent::UserPrompt,
                cancel_rx,
                window,
                cx,
            )
            .await?;

            // Clear the pending turn so `is_generating` reports false again.
            this.update(cx, |this, _cx| this.pending_turn.take()).ok();

            Ok(())
        }),
        cancel_tx,
    });
}
2206
2207 async fn turn_loop(
2208 this: &WeakEntity<Self>,
2209 model: Arc<dyn LanguageModel>,
2210 mut intent: CompletionIntent,
2211 mut cancel_rx: oneshot::Receiver<()>,
2212 window: Option<AnyWindowHandle>,
2213 cx: &mut AsyncApp,
2214 ) -> Result<()> {
2215 struct RetryState {
2216 attempts: u8,
2217 custom_delay: Option<Duration>,
2218 }
2219 let mut retry_state: Option<RetryState> = None;
2220
2221 loop {
2222 let mut assistant_message = LanguageModelRequestMessage {
2223 role: Role::Assistant,
2224 content: Vec::new(),
2225 cache: false,
2226 };
2227 let mut tool_results_message = LanguageModelRequestMessage {
2228 role: Role::User,
2229 content: Vec::new(),
2230 cache: false,
2231 };
2232 let mut pending_tool_uses = Vec::new();
2233
2234 let send = async {
2235 if let Some(retry_state) = retry_state.as_ref() {
2236 let delay = retry_state.custom_delay.unwrap_or_else(|| {
2237 BASE_RETRY_DELAY * 2_u32.pow((retry_state.attempts - 1) as u32)
2238 });
2239 cx.background_executor().timer(delay).await;
2240 }
2241
2242 let request = this.update(cx, |this, cx| this.build_request(&model, intent, cx))?;
2243 let mut events = model.stream_completion(request.clone(), cx).await?;
2244
2245 while let Some(event) = events.next().await {
2246 let event = event?;
2247 match event {
2248 LanguageModelCompletionEvent::StartMessage { .. } => {
2249 this.update(cx, |agent, cx| {
2250 agent.insert_assistant_message(vec![], cx)
2251 })?;
2252 }
2253 LanguageModelCompletionEvent::Text(chunk) => {
2254 this.update(cx, |this, cx| {
2255 this.push_assistant_message_segment(
2256 MessageSegment::Text(chunk.clone()),
2257 cx,
2258 );
2259 })?;
2260 assistant_message.push(MessageContent::Text(chunk));
2261 }
2262 LanguageModelCompletionEvent::Thinking { text, signature } => {
2263 this.update(cx, |thread, cx| {
2264 thread.push_assistant_message_segment(
2265 MessageSegment::Thinking {
2266 text: text.clone(),
2267 signature: signature.clone(),
2268 },
2269 cx,
2270 );
2271 })?;
2272 assistant_message.push(MessageContent::Thinking { text, signature });
2273 }
2274 LanguageModelCompletionEvent::RedactedThinking { data } => {
2275 assistant_message.push(MessageContent::RedactedThinking(data));
2276 }
2277 LanguageModelCompletionEvent::ToolUse(tool_use) => {
2278 // todo!("update tool card")
2279 // todo! tool input streaming
2280 if tool_use.is_input_complete {
2281 let mut pending_request = request.clone();
2282 pending_request.messages.push(assistant_message.clone());
2283
2284 match this.read_with(cx, |this, cx| {
2285 this.tool_for_name(&tool_use.name, cx)
2286 })? {
2287 Ok(tool) => {
2288 // todo!("handle confirmation")
2289 // let confirmed = if tool.needs_confirmation(&tool_use.input, cx)
2290 // && !AgentSettings::get_global(cx).always_allow_tool_actions
2291 // {
2292 // thread.update(cx, |thread,cx| thread.gimme_confirmation()).await;
2293 // } else {
2294 // true
2295 // };
2296 let tool_result = this.update(cx, |this, cx| {
2297 tool.run(
2298 tool_use.input.clone(),
2299 Arc::new(pending_request),
2300 this.project.clone(),
2301 this.action_log(),
2302 model.clone(),
2303 window,
2304 cx,
2305 )
2306 })?;
2307 let index = this.update(cx, |thread, cx| {
2308 thread.push_tool_call(
2309 tool_use.name.clone(),
2310 tool_use.input.clone(),
2311 tool_result.card,
2312 cx,
2313 )
2314 })?;
2315 pending_tool_uses.push(PendingToolUse2 {
2316 index_in_message: index,
2317 output: Some(tool_result.output),
2318 request: tool_use,
2319 });
2320 }
2321 Err(error) => {
2322 let index = this.update(cx, |thread, cx| {
2323 thread.push_tool_call(
2324 tool_use.name.clone(),
2325 tool_use.input.clone(),
2326 None,
2327 cx,
2328 )
2329 })?;
2330 pending_tool_uses.push(PendingToolUse2 {
2331 index_in_message: index,
2332 request: tool_use,
2333 output: Some(Task::ready(Err(error))),
2334 });
2335 }
2336 }
2337 }
2338 }
2339 LanguageModelCompletionEvent::UsageUpdate(_token_usage) => {
2340 // todo!
2341 }
2342 LanguageModelCompletionEvent::StatusUpdate(_completion_request_status) => {
2343 // todo!
2344 }
2345 LanguageModelCompletionEvent::Stop(StopReason::EndTurn) => {
2346 // todo!
2347 }
2348 LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) => {
2349 // todo!
2350 }
2351 LanguageModelCompletionEvent::Stop(StopReason::Refusal) => {
2352 // todo!
2353 }
2354 LanguageModelCompletionEvent::Stop(StopReason::ToolUse) => {}
2355 }
2356 }
2357
2358 while let Some(mut pending_tool_use) = pending_tool_uses.pop() {
2359 let (tool_result, thread_result) = pending_tool_use.result().await;
2360 this.update(cx, |thread, cx| {
2361 thread.set_tool_call_result(
2362 pending_tool_use.index_in_message,
2363 thread_result,
2364 cx,
2365 )
2366 })?;
2367 assistant_message.push(MessageContent::ToolUse(pending_tool_use.request));
2368 tool_results_message
2369 .content
2370 .push(MessageContent::ToolResult(tool_result));
2371 }
2372
2373 anyhow::Ok(())
2374 }
2375 .boxed_local();
2376
2377 enum SendStatus {
2378 Canceled,
2379 Finished(Result<()>),
2380 }
2381
2382 let status = match futures::future::select(&mut cancel_rx, send).await {
2383 Either::Left(_) => SendStatus::Canceled,
2384 Either::Right((result, _)) => SendStatus::Finished(result),
2385 };
2386
2387 match status {
2388 SendStatus::Canceled => {
2389 for pending_tool_use in pending_tool_uses {
2390 tool_results_message
2391 .content
2392 .push(MessageContent::ToolResult(LanguageModelToolResult {
2393 tool_use_id: pending_tool_use.request.id.clone(),
2394 tool_name: pending_tool_use.request.name.clone(),
2395 is_error: true,
2396 content: LanguageModelToolResultContent::Text(
2397 "<User cancelled tool use>".into(),
2398 ),
2399 output: None,
2400 }));
2401 assistant_message.push(MessageContent::ToolUse(pending_tool_use.request));
2402 }
2403
2404 this.update(cx, |this, _cx| {
2405 if !assistant_message.content.is_empty() {
2406 this.messages.push(assistant_message);
2407 }
2408
2409 if !tool_results_message.content.is_empty() {
2410 this.messages.push(tool_results_message);
2411 }
2412 })?;
2413
2414 break;
2415 }
2416 SendStatus::Finished(result) => {
2417 for mut pending_tool_use in pending_tool_uses {
2418 let (tool_result, thread_result) = pending_tool_use.result().await;
2419 this.update(cx, |thread, cx| {
2420 thread.set_tool_call_result(
2421 pending_tool_use.index_in_message,
2422 thread_result,
2423 cx,
2424 )
2425 })?;
2426 assistant_message.push(MessageContent::ToolUse(pending_tool_use.request));
2427 tool_results_message
2428 .content
2429 .push(MessageContent::ToolResult(tool_result));
2430 }
2431
2432 match result {
2433 Ok(_) => {
2434 retry_state = None;
2435 }
2436 Err(error) => {
2437 let mut retry = |custom_delay: Option<Duration>| -> bool {
2438 let retry_state = retry_state.get_or_insert_with(|| RetryState {
2439 attempts: 0,
2440 custom_delay,
2441 });
2442 retry_state.attempts += 1;
2443 retry_state.attempts <= MAX_RETRY_ATTEMPTS
2444 };
2445
2446 if error.is::<PaymentRequiredError>() {
2447 // todo!
2448 // cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
2449 } else if let Some(_error) =
2450 error.downcast_ref::<ModelRequestLimitReachedError>()
2451 {
2452 // todo!
2453 // cx.emit(ThreadEvent::ShowError(
2454 // ThreadError::ModelRequestLimitReached { plan: error.plan },
2455 // ));
2456 } else if let Some(completion_error) =
2457 error.downcast_ref::<LanguageModelCompletionError>()
2458 {
2459 match completion_error {
2460 LanguageModelCompletionError::RateLimitExceeded {
2461 retry_after,
2462 } => {
2463 if !retry(Some(*retry_after)) {
2464 break;
2465 }
2466 }
2467 LanguageModelCompletionError::Overloaded => {
2468 if !retry(None) {
2469 break;
2470 }
2471 }
2472 LanguageModelCompletionError::ApiInternalServerError => {
2473 if !retry(None) {
2474 break;
2475 }
2476 // todo!
2477 }
2478 _ => {
2479 // todo!(emit_generic_error(error, cx);)
2480 break;
2481 }
2482 }
2483 } else if let Some(known_error) =
2484 error.downcast_ref::<LanguageModelKnownError>()
2485 {
2486 match known_error {
2487 LanguageModelKnownError::ContextWindowLimitExceeded {
2488 tokens: _,
2489 } => {
2490 // todo!
2491 // this.exceeded_window_error =
2492 // Some(ExceededWindowError {
2493 // model_id: model.id(),
2494 // token_count: *tokens,
2495 // });
2496 // cx.notify();
2497 break;
2498 }
2499 LanguageModelKnownError::RateLimitExceeded { retry_after } => {
2500 // let provider_name = model.provider_name();
2501 // let error_message = format!(
2502 // "{}'s API rate limit exceeded",
2503 // provider_name.0.as_ref()
2504 // );
2505 if !retry(Some(*retry_after)) {
2506 // todo! show err
2507 break;
2508 }
2509 }
2510 LanguageModelKnownError::Overloaded => {
2511 //todo!
2512 // let provider_name = model.provider_name();
2513 // let error_message = format!(
2514 // "{}'s API servers are overloaded right now",
2515 // provider_name.0.as_ref()
2516 // );
2517
2518 if !retry(None) {
2519 // todo! show err
2520 break;
2521 }
2522 }
2523 LanguageModelKnownError::ApiInternalServerError => {
2524 // let provider_name = model.provider_name();
2525 // let error_message = format!(
2526 // "{}'s API server reported an internal server error",
2527 // provider_name.0.as_ref()
2528 // );
2529
2530 if !retry(None) {
2531 break;
2532 }
2533 }
2534 LanguageModelKnownError::ReadResponseError(_)
2535 | LanguageModelKnownError::DeserializeResponse(_)
2536 | LanguageModelKnownError::UnknownResponseFormat(_) => {
2537 // In the future we will attempt to re-roll response, but only once
2538 // todo!(emit_generic_error(error, cx);)
2539 break;
2540 }
2541 }
2542 } else {
2543 // todo!(emit_generic_error(error, cx));
2544 break;
2545 }
2546 }
2547 }
2548
2549 let done = this.update(cx, |this, cx| {
2550 let done = if assistant_message.content.is_empty() {
2551 true
2552 } else {
2553 this.messages.push(assistant_message);
2554 if tool_results_message.content.is_empty() {
2555 true
2556 } else {
2557 this.messages.push(tool_results_message);
2558 false
2559 }
2560 };
2561
2562 let summary_pending = matches!(this.summary(), ThreadSummary::Pending);
2563
2564 if summary_pending && (done || this.messages.len() > 6) {
2565 this.summarize(cx);
2566 }
2567
2568 done
2569 })?;
2570
2571 if done && retry_state.is_none() {
2572 break;
2573 } else {
2574 intent = CompletionIntent::ToolResults;
2575 }
2576 }
2577 }
2578 }
2579
2580 Ok(())
2581 }
2582
2583 fn build_request(
2584 &mut self,
2585 model: &Arc<dyn LanguageModel>,
2586 intent: CompletionIntent,
2587 cx: &mut Context<Self>,
2588 ) -> LanguageModelRequest {
2589 let mode = if model.supports_burn_mode() {
2590 Some(self.completion_mode.into())
2591 } else {
2592 Some(CompletionMode::Normal.into())
2593 };
2594
2595 let available_tools = self.available_tools(cx, model.clone());
2596 let available_tool_names = available_tools
2597 .iter()
2598 .map(|tool| tool.name.clone())
2599 .collect();
2600
2601 let mut request = LanguageModelRequest {
2602 thread_id: Some(self.id().to_string()),
2603 prompt_id: Some(self.last_prompt_id.to_string()),
2604 intent: Some(intent),
2605 mode,
2606 messages: vec![],
2607 tools: available_tools,
2608 tool_choice: None,
2609 stop: Vec::new(),
2610 temperature: AgentSettings::temperature_for_model(&model, cx),
2611 };
2612
2613 let model_context = &ModelContext {
2614 available_tools: available_tool_names,
2615 };
2616
2617 // todo!("should we cache the system prompt and append a message when it changes, as opposed to replacing it?")
2618 if let Some(project_context) = self.project_context.borrow().as_ref() {
2619 match self
2620 .prompt_builder
2621 .generate_assistant_system_prompt(project_context, model_context)
2622 {
2623 Err(err) => {
2624 let message = format!("{err:?}").into();
2625 log::error!("{message}");
2626 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
2627 header: "Error generating system prompt".into(),
2628 message,
2629 }));
2630 }
2631 Ok(system_prompt) => {
2632 request.messages.push(LanguageModelRequestMessage {
2633 role: Role::System,
2634 content: vec![MessageContent::Text(system_prompt)],
2635 cache: true,
2636 });
2637 }
2638 }
2639 } else {
2640 let message = "Context for system prompt unexpectedly not ready.".into();
2641 log::error!("{message}");
2642 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
2643 header: "Error generating system prompt".into(),
2644 message,
2645 }));
2646 }
2647
2648 request.messages.extend(self.messages.iter().cloned());
2649 request
2650 }
2651
2652 // todo! only used in eval. make private somehow?
2653 pub fn insert_user_message(
2654 &mut self,
2655 params: impl Into<UserMessageParams>,
2656 cx: &mut Context<Self>,
2657 ) -> MessageId {
2658 let params = params.into();
2659 let req_ix = self.insert_request_user_message(¶ms);
2660
2661 if !params.context.referenced_buffers.is_empty() {
2662 self.action_log.clone().update(cx, |log, cx| {
2663 for buffer in params.context.referenced_buffers {
2664 log.buffer_read(buffer, cx);
2665 }
2666 });
2667 }
2668
2669 let message_id = self.insert_message(
2670 Role::User,
2671 vec![MessageSegment::Text(params.text)],
2672 params.context.loaded_context,
2673 params.creases,
2674 cx,
2675 );
2676
2677 self.thread_user_messages.insert(message_id, req_ix);
2678
2679 if let Some(git_checkpoint) = params.checkpoint {
2680 self.pending_checkpoint = Some(ThreadCheckpoint {
2681 message_id,
2682 git_checkpoint,
2683 })
2684 }
2685
2686 self.auto_capture_telemetry(cx);
2687
2688 message_id
2689 }
2690
2691 fn insert_request_user_message(&mut self, params: &UserMessageParams) -> usize {
2692 let mut request_message = LanguageModelRequestMessage {
2693 role: Role::User,
2694 content: vec![],
2695 cache: false,
2696 };
2697
2698 params
2699 .context
2700 .loaded_context
2701 .add_to_request_message(&mut request_message);
2702 request_message
2703 .content
2704 .push(MessageContent::Text(params.text.clone()));
2705
2706 let ix = self.messages.len();
2707 self.messages.push(request_message);
2708 ix
2709 }
2710
2711 pub fn send_to_model(
2712 &mut self,
2713 model: Arc<dyn LanguageModel>,
2714 intent: CompletionIntent,
2715 window: Option<AnyWindowHandle>,
2716 cx: &mut Context<Self>,
2717 ) {
2718 if self.remaining_turns == 0 {
2719 return;
2720 }
2721
2722 self.remaining_turns -= 1;
2723
2724 let request = self.to_completion_request(model.clone(), intent, cx);
2725
2726 self.stream_completion(request, model, intent, window, cx);
2727 }
2728
2729 pub fn used_tools_since_last_user_message(&self, _cx: &App) -> bool {
2730 for message in self.messages().rev() {
2731 let message_has_tool_results = self
2732 .tool_uses_by_assistant_message
2733 .get(&message.id)
2734 .map_or(false, |results| !results.is_empty());
2735 if message_has_tool_results {
2736 return true;
2737 } else if message.role == Role::User {
2738 return false;
2739 }
2740 }
2741
2742 false
2743 }
2744
    /// Builds the complete `LanguageModelRequest` for this thread: system
    /// prompt, every non-UI message with its loaded context and tool results,
    /// a single prompt-cache marker, and the available tools.
    ///
    /// NOTE(review): hitting a `MessageSegment::ToolUse` panics via `todo!`
    /// below — per its message, this method appears slated for removal.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id().to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from project context. If context isn't
        // ready (or generation fails) we surface an error to the UI but still
        // send the rest of the request without a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is the stable prefix of every request,
                    // so it is always marked cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible for prompt caching; only
        // messages whose tool uses all have completed results qualify.
        let mut message_ix_to_cache = None;
        for message in &self.thread_messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Loaded context (attached files, etc.) precedes the segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Empty segments are dropped rather than sent.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::ToolUse { .. } => {
                        todo!("remove this whole method")
                    }
                };
            }

            // Pair each tool use on this message with its result. Completed
            // results go into a separate user-role message, as the protocol
            // expects; a still-pending result disqualifies this message from
            // caching and is skipped entirely.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // Mark only the last cacheable message; the provider caches the prefix
        // up to and including it.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn mode is only honored when the model supports it.
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
2905
2906 fn to_summarize_request(
2907 &self,
2908 model: &Arc<dyn LanguageModel>,
2909 intent: CompletionIntent,
2910 added_user_message: String,
2911 cx: &App,
2912 ) -> LanguageModelRequest {
2913 let mut request = LanguageModelRequest {
2914 thread_id: None,
2915 prompt_id: None,
2916 intent: Some(intent),
2917 mode: None,
2918 messages: vec![],
2919 tools: Vec::new(),
2920 tool_choice: None,
2921 stop: Vec::new(),
2922 temperature: AgentSettings::temperature_for_model(model, cx),
2923 };
2924
2925 for message in &self.thread_messages {
2926 let mut request_message = LanguageModelRequestMessage {
2927 role: message.role,
2928 content: Vec::new(),
2929 cache: false,
2930 };
2931
2932 for segment in &message.segments {
2933 match segment {
2934 MessageSegment::Text(text) => request_message
2935 .content
2936 .push(MessageContent::Text(text.clone())),
2937 MessageSegment::Thinking { .. } => {}
2938 MessageSegment::ToolUse { .. } => {}
2939 }
2940 }
2941
2942 if request_message.content.is_empty() {
2943 continue;
2944 }
2945
2946 request.messages.push(request_message);
2947 }
2948
2949 request.messages.push(LanguageModelRequestMessage {
2950 role: Role::User,
2951 content: vec![MessageContent::Text(added_user_message)],
2952 cache: false,
2953 });
2954
2955 request
2956 }
2957
2958 pub fn stream_completion(
2959 &mut self,
2960 request: LanguageModelRequest,
2961 model: Arc<dyn LanguageModel>,
2962 intent: CompletionIntent,
2963 window: Option<AnyWindowHandle>,
2964 cx: &mut Context<Self>,
2965 ) {
2966 self.tool_use_limit_reached = false;
2967
2968 let pending_completion_id = post_inc(&mut self.completion_count);
2969 let mut request_callback_parameters = if self.request_callback.is_some() {
2970 Some((request.clone(), Vec::new()))
2971 } else {
2972 None
2973 };
2974 let prompt_id = self.last_prompt_id.clone();
2975 let tool_use_metadata = ToolUseMetadata {
2976 model: model.clone(),
2977 thread_id: self.id().clone(),
2978 prompt_id: prompt_id.clone(),
2979 };
2980
2981 self.last_received_chunk_at = Some(Instant::now());
2982
2983 let task = cx.spawn(async move |this, cx| {
2984 let stream_completion_future = model.stream_completion(request, &cx);
2985 let initial_token_usage =
2986 this.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
2987 let stream_completion = async {
2988 let mut events = stream_completion_future.await?;
2989
2990 let mut stop_reason = StopReason::EndTurn;
2991 let mut current_token_usage = TokenUsage::default();
2992
2993 this
2994 .update(cx, |_thread, cx| {
2995 cx.emit(ThreadEvent::NewRequest);
2996 })
2997 .ok();
2998
2999 let mut request_assistant_message_id = None;
3000
3001 while let Some(event) = events.next().await {
3002 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
3003 response_events
3004 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
3005 }
3006
3007 this.update(cx, |this, cx| {
3008 let event = match event {
3009 Ok(event) => event,
3010 Err(error) => {
3011 match error {
3012 LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
3013 anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
3014 }
3015 LanguageModelCompletionError::Overloaded => {
3016 anyhow::bail!(LanguageModelKnownError::Overloaded);
3017 }
3018 LanguageModelCompletionError::ApiInternalServerError =>{
3019 anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
3020 }
3021 LanguageModelCompletionError::PromptTooLarge { tokens } => {
3022 let tokens = tokens.unwrap_or_else(|| {
3023 // We didn't get an exact token count from the API, so fall back on our estimate.
3024 this.total_token_usage(cx)
3025 .map(|usage| usage.total)
3026 .unwrap_or(0)
3027 // We know the context window was exceeded in practice, so if our estimate was
3028 // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
3029 .max(model.max_token_count().saturating_add(1))
3030 });
3031
3032 anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
3033 }
3034 LanguageModelCompletionError::ApiReadResponseError(io_error) => {
3035 anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
3036 }
3037 LanguageModelCompletionError::UnknownResponseFormat(error) => {
3038 anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
3039 }
3040 LanguageModelCompletionError::HttpResponseError { status, ref body } => {
3041 if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
3042 anyhow::bail!(known_error);
3043 } else {
3044 return Err(error.into());
3045 }
3046 }
3047 LanguageModelCompletionError::DeserializeResponse(error) => {
3048 anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
3049 }
3050 LanguageModelCompletionError::BadInputJson {
3051 id,
3052 tool_name,
3053 raw_input: invalid_input_json,
3054 json_parse_error,
3055 } => {
3056 this.receive_invalid_tool_json(
3057 id,
3058 tool_name,
3059 invalid_input_json,
3060 json_parse_error,
3061 window,
3062 cx,
3063 );
3064 return Ok(());
3065 }
3066 // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
3067 err @ LanguageModelCompletionError::BadRequestFormat |
3068 err @ LanguageModelCompletionError::AuthenticationError |
3069 err @ LanguageModelCompletionError::PermissionError |
3070 err @ LanguageModelCompletionError::ApiEndpointNotFound |
3071 err @ LanguageModelCompletionError::SerializeRequest(_) |
3072 err @ LanguageModelCompletionError::BuildRequestBody(_) |
3073 err @ LanguageModelCompletionError::HttpSend(_) => {
3074 anyhow::bail!(err);
3075 }
3076 LanguageModelCompletionError::Other(error) => {
3077 return Err(error);
3078 }
3079 }
3080 }
3081 };
3082
3083 match event {
3084 LanguageModelCompletionEvent::StartMessage { .. } => {
3085 request_assistant_message_id =
3086 Some(this.insert_assistant_message(
3087 vec![MessageSegment::Text(String::new())],
3088 cx,
3089 ));
3090 }
3091 LanguageModelCompletionEvent::Stop(reason) => {
3092 stop_reason = reason;
3093 }
3094 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
3095 this.update_token_usage_at_last_message(token_usage, cx);
3096 this.cumulative_token_usage = this.cumulative_token_usage
3097 + token_usage
3098 - current_token_usage;
3099 current_token_usage = token_usage;
3100 }
3101 LanguageModelCompletionEvent::Text(chunk) => {
3102 this.received_chunk();
3103
3104 cx.emit(ThreadEvent::ReceivedTextChunk);
3105 if let Some(last_message) = this.thread_messages.last_mut() {
3106 if last_message.role == Role::Assistant
3107 && !this.tool_uses_by_assistant_message.contains_key(&last_message.id)
3108 {
3109 last_message.push_text(chunk.clone());
3110 cx.emit(ThreadEvent::StreamedAssistantText(
3111 last_message.id,
3112 chunk,
3113 ));
3114 } else {
3115 // If we won't have an Assistant message yet, assume this chunk marks the beginning
3116 // of a new Assistant response.
3117 //
3118 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
3119 // will result in duplicating the text of the chunk in the rendered Markdown.
3120 request_assistant_message_id =
3121 Some(this.insert_assistant_message(
3122 vec![MessageSegment::Text(chunk.to_string())],
3123 cx,
3124 ));
3125 };
3126 }
3127 }
3128 LanguageModelCompletionEvent::Thinking {
3129 text: chunk,
3130 signature,
3131 } => {
3132 this.received_chunk();
3133 if let Some(last_message) = this.thread_messages.last_mut() {
3134 if last_message.role == Role::Assistant
3135 && !this.tool_uses_by_assistant_message.contains_key(&last_message.id)
3136 {
3137 last_message.push_thinking(chunk.clone(), signature);
3138 cx.emit(ThreadEvent::StreamedAssistantThinking(
3139 last_message.id,
3140 chunk,
3141 ));
3142 } else {
3143 // If we won't have an Assistant message yet, assume this chunk marks the beginning
3144 // of a new Assistant response.
3145 //
3146 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
3147 // will result in duplicating the text of the chunk in the rendered Markdown.
3148 request_assistant_message_id =
3149 Some(this.insert_assistant_message(
3150 vec![MessageSegment::Thinking {
3151 text: chunk.to_string(),
3152 signature,
3153 }],
3154 cx,
3155 ));
3156 };
3157 }
3158 }
3159 LanguageModelCompletionEvent::RedactedThinking {
3160 ..
3161 } => {
3162 // no more readacted thinking, think in the open
3163 }
3164 LanguageModelCompletionEvent::ToolUse(tool_use) => {
3165 let last_assistant_message_id = request_assistant_message_id
3166 .unwrap_or_else(|| {
3167 let new_assistant_message_id =
3168 this.insert_assistant_message(vec![], cx);
3169 request_assistant_message_id =
3170 Some(new_assistant_message_id);
3171 new_assistant_message_id
3172 });
3173
3174 let tool_use_id = tool_use.id.clone();
3175 let streamed_input = if tool_use.is_input_complete {
3176 None
3177 } else {
3178 Some((&tool_use.input).clone())
3179 };
3180
3181 let ui_text = this.request_tool_use(
3182 last_assistant_message_id,
3183 tool_use,
3184 tool_use_metadata.clone(),
3185 cx,
3186 );
3187
3188 if let Some(input) = streamed_input {
3189 cx.emit(ThreadEvent::StreamedToolUse {
3190 tool_use_id,
3191 ui_text,
3192 input,
3193 });
3194 }
3195 }
3196 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
3197 if let Some(completion) = this
3198 .pending_completions
3199 .iter_mut()
3200 .find(|completion| completion.id == pending_completion_id)
3201 {
3202 match status_update {
3203 CompletionRequestStatus::Queued {
3204 position,
3205 } => {
3206 completion.queue_state = QueueState::Queued { position };
3207 }
3208 CompletionRequestStatus::Started => {
3209 completion.queue_state = QueueState::Started;
3210 }
3211 CompletionRequestStatus::Failed {
3212 code, message, request_id
3213 } => {
3214 anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
3215 }
3216 CompletionRequestStatus::UsageUpdated {
3217 amount, limit
3218 } => {
3219 this.update_model_request_usage(amount as u32, limit, cx);
3220 }
3221 CompletionRequestStatus::ToolUseLimitReached => {
3222 this.tool_use_limit_reached = true;
3223 cx.emit(ThreadEvent::ToolUseLimitReached);
3224 }
3225 }
3226 }
3227 }
3228 }
3229
3230 this.touch_updated_at();
3231 cx.emit(ThreadEvent::StreamedCompletion);
3232 cx.notify();
3233
3234 this.auto_capture_telemetry(cx);
3235 Ok(())
3236 })??;
3237
3238 smol::future::yield_now().await;
3239 }
3240
3241 this.update(cx, |this, cx| {
3242 this.last_received_chunk_at.take();
3243 this
3244 .pending_completions
3245 .retain(|completion| completion.id != pending_completion_id);
3246
3247
3248 if matches!(this.summary, ThreadSummary::Pending)
3249 && this.messages().len() >= 2
3250 && (!this.has_pending_tool_uses() || this.messages().len() >= 6)
3251 {
3252 this.summarize(cx);
3253 }
3254 })?;
3255
3256 anyhow::Ok(stop_reason)
3257 };
3258
3259 let result = stream_completion.await;
3260 let mut retry_scheduled = false;
3261
3262 this
3263 .update(cx, |this, cx| {
3264 this.finalize_pending_checkpoint(cx);
3265 match result.as_ref() {
3266 Ok(stop_reason) => {
3267 match stop_reason {
3268 StopReason::ToolUse => {
3269 let tool_uses = this.use_pending_tools(window, model.clone(), cx);
3270 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
3271 }
3272 StopReason::EndTurn | StopReason::MaxTokens => {
3273 this.project.update(cx, |project, cx| {
3274 project.set_agent_location(None, cx);
3275 });
3276 }
3277 StopReason::Refusal => {
3278 this.project.update(cx, |project, cx| {
3279 project.set_agent_location(None, cx);
3280 });
3281
3282 // Remove the turn that was refused.
3283 //
3284 // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
3285 {
3286 // todo! move this to turn_loop
3287 // let mut messages_to_remove = Vec::new();
3288
3289 // for (ix, message) in this.thread.read(cx).messages().enumerate().rev() {
3290 // messages_to_remove.push(message.id);
3291
3292 // if message.role == Role::User {
3293 // if ix == 0 {
3294 // break;
3295 // }
3296
3297 // if let Some(prev_message) = this.thread.read(cx).messages.get(ix - 1) {
3298 // if prev_message.role == Role::Assistant {
3299 // break;
3300 // }
3301 // }
3302 // }
3303 // }
3304
3305 // for message_id in messages_to_remove {
3306 // this.delete_message(message_id, cx);
3307 // }
3308 }
3309
3310 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
3311 header: "Language model refusal".into(),
3312 message: "Model refused to generate content for safety reasons.".into(),
3313 }));
3314 }
3315 }
3316
3317 // We successfully completed, so cancel any remaining retries.
3318 this.retry_state = None;
3319 },
3320 Err(error) => {
3321 this.project.update(cx, |project, cx| {
3322 project.set_agent_location(None, cx);
3323 });
3324
3325 fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<ZedAgentThread>) {
3326 let error_message = error
3327 .chain()
3328 .map(|err| err.to_string())
3329 .collect::<Vec<_>>()
3330 .join("\n");
3331 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
3332 header: "Error interacting with language model".into(),
3333 message: SharedString::from(error_message.clone()),
3334 }));
3335 }
3336
3337 if error.is::<PaymentRequiredError>() {
3338 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
3339 } else if let Some(error) =
3340 error.downcast_ref::<ModelRequestLimitReachedError>()
3341 {
3342 cx.emit(ThreadEvent::ShowError(
3343 ThreadError::ModelRequestLimitReached { plan: error.plan },
3344 ));
3345 } else if let Some(known_error) =
3346 error.downcast_ref::<LanguageModelKnownError>()
3347 {
3348 match known_error {
3349 LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
3350 this.exceeded_window_error = Some(ExceededWindowError {
3351 model_id: model.id(),
3352 token_count: *tokens,
3353 });
3354 cx.notify();
3355 }
3356 LanguageModelKnownError::RateLimitExceeded { retry_after } => {
3357 let provider_name = model.provider_name();
3358 let error_message = format!(
3359 "{}'s API rate limit exceeded",
3360 provider_name.0.as_ref()
3361 );
3362
3363 this.handle_rate_limit_error(
3364 &error_message,
3365 *retry_after,
3366 model.clone(),
3367 intent,
3368 window,
3369 cx,
3370 );
3371 retry_scheduled = true;
3372 }
3373 LanguageModelKnownError::Overloaded => {
3374 let provider_name = model.provider_name();
3375 let error_message = format!(
3376 "{}'s API servers are overloaded right now",
3377 provider_name.0.as_ref()
3378 );
3379
3380 retry_scheduled = this.handle_retryable_error(
3381 &error_message,
3382 model.clone(),
3383 intent,
3384 window,
3385 cx,
3386 );
3387 if !retry_scheduled {
3388 emit_generic_error(error, cx);
3389 }
3390 }
3391 LanguageModelKnownError::ApiInternalServerError => {
3392 let provider_name = model.provider_name();
3393 let error_message = format!(
3394 "{}'s API server reported an internal server error",
3395 provider_name.0.as_ref()
3396 );
3397
3398 retry_scheduled = this.handle_retryable_error(
3399 &error_message,
3400 model.clone(),
3401 intent,
3402 window,
3403 cx,
3404 );
3405 if !retry_scheduled {
3406 emit_generic_error(error, cx);
3407 }
3408 }
3409 LanguageModelKnownError::ReadResponseError(_) |
3410 LanguageModelKnownError::DeserializeResponse(_) |
3411 LanguageModelKnownError::UnknownResponseFormat(_) => {
3412 // In the future we will attempt to re-roll response, but only once
3413 emit_generic_error(error, cx);
3414 }
3415 }
3416 } else {
3417 emit_generic_error(error, cx);
3418 }
3419
3420 if !retry_scheduled {
3421 this.cancel_last_completion(window, cx);
3422 }
3423 }
3424 }
3425
3426 if !retry_scheduled {
3427 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
3428 }
3429
3430 if let Some((request_callback, (request, response_events))) = this
3431 .request_callback
3432 .as_mut()
3433 .zip(request_callback_parameters.as_ref())
3434 {
3435 request_callback(request, response_events);
3436 }
3437
3438 this.auto_capture_telemetry(cx);
3439
3440 if let Ok(initial_usage) = initial_token_usage {
3441 let usage = this.cumulative_token_usage - initial_usage;
3442
3443 telemetry::event!(
3444 "Assistant Thread Completion",
3445 thread_id = this.id().to_string(),
3446 prompt_id = prompt_id,
3447 model = model.telemetry_id(),
3448 model_provider = model.provider_id().to_string(),
3449 input_tokens = usage.input_tokens,
3450 output_tokens = usage.output_tokens,
3451 cache_creation_input_tokens = usage.cache_creation_input_tokens,
3452 cache_read_input_tokens = usage.cache_read_input_tokens,
3453 );
3454 }
3455 })
3456 .ok();
3457 });
3458
3459 self.pending_completions.push(PendingCompletion {
3460 id: pending_completion_id,
3461 queue_state: QueueState::Sending,
3462 _task: task,
3463 });
3464 }
3465
3466 pub fn summarize(&mut self, cx: &mut Context<Self>) {
3467 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
3468 println!("No thread summary model");
3469 return;
3470 };
3471
3472 if !model.provider.is_authenticated(cx) {
3473 return;
3474 }
3475
3476 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
3477
3478 let request = self.to_summarize_request(
3479 &model.model,
3480 CompletionIntent::ThreadSummarization,
3481 added_user_message.into(),
3482 cx,
3483 );
3484
3485 self.summary = ThreadSummary::Generating;
3486
3487 self.pending_summary = cx.spawn(async move |this, cx| {
3488 let result = async {
3489 let mut messages = model.model.stream_completion(request, &cx).await?;
3490
3491 let mut new_summary = String::new();
3492 while let Some(event) = messages.next().await {
3493 let Ok(event) = event else {
3494 continue;
3495 };
3496 let text = match event {
3497 LanguageModelCompletionEvent::Text(text) => text,
3498 LanguageModelCompletionEvent::StatusUpdate(
3499 CompletionRequestStatus::UsageUpdated { amount, limit },
3500 ) => {
3501 this.update(cx, |thread, cx| {
3502 thread.update_model_request_usage(amount as u32, limit, cx);
3503 })?;
3504 continue;
3505 }
3506 _ => continue,
3507 };
3508
3509 let mut lines = text.lines();
3510 new_summary.extend(lines.next());
3511
3512 // Stop if the LLM generated multiple lines.
3513 if lines.next().is_some() {
3514 break;
3515 }
3516 }
3517
3518 anyhow::Ok(new_summary)
3519 }
3520 .await;
3521
3522 this.update(cx, |thread, cx| {
3523 match result {
3524 Ok(new_summary) => {
3525 if new_summary.is_empty() {
3526 thread.summary = ThreadSummary::Error;
3527 } else {
3528 thread.summary = ThreadSummary::Ready(new_summary.into());
3529 }
3530 }
3531 Err(err) => {
3532 thread.summary = ThreadSummary::Error;
3533 log::error!("Failed to generate thread summary: {}", err);
3534 }
3535 }
3536 cx.emit(ThreadEvent::SummaryGenerated);
3537 })
3538 .log_err()?;
3539
3540 Some(())
3541 });
3542 }
3543
3544 fn handle_rate_limit_error(
3545 &mut self,
3546 error_message: &str,
3547 retry_after: Duration,
3548 model: Arc<dyn LanguageModel>,
3549 intent: CompletionIntent,
3550 window: Option<AnyWindowHandle>,
3551 cx: &mut Context<Self>,
3552 ) {
3553 // For rate limit errors, we only retry once with the specified duration
3554 let retry_message = format!(
3555 "{error_message}. Retrying in {} seconds…",
3556 retry_after.as_secs()
3557 );
3558
3559 self.insert_retry_message(retry_message, cx);
3560 // Schedule the retry
3561 let thread_handle = cx.entity().downgrade();
3562
3563 cx.spawn(async move |_thread, cx| {
3564 cx.background_executor().timer(retry_after).await;
3565
3566 thread_handle
3567 .update(cx, |thread, cx| {
3568 // Retry the completion
3569 thread.send_to_model(model, intent, window, cx);
3570 })
3571 .log_err();
3572 })
3573 .detach();
3574 }
3575
    /// Schedules a retry for a retryable completion error using the default
    /// exponential-backoff delay (no custom delay).
    ///
    /// Returns `true` if a retry was scheduled, `false` once the retry budget
    /// is exhausted. See `handle_retryable_error_with_delay` for details.
    fn handle_retryable_error(
        &mut self,
        error_message: &str,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx)
    }
3586
3587 fn handle_retryable_error_with_delay(
3588 &mut self,
3589 error_message: &str,
3590 custom_delay: Option<Duration>,
3591 model: Arc<dyn LanguageModel>,
3592 intent: CompletionIntent,
3593 window: Option<AnyWindowHandle>,
3594 cx: &mut Context<Self>,
3595 ) -> bool {
3596 let retry_state = self.retry_state.get_or_insert(RetryState {
3597 attempt: 0,
3598 max_attempts: MAX_RETRY_ATTEMPTS,
3599 intent,
3600 });
3601
3602 retry_state.attempt += 1;
3603 let attempt = retry_state.attempt;
3604 let max_attempts = retry_state.max_attempts;
3605 let intent = retry_state.intent;
3606
3607 if attempt <= max_attempts {
3608 // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
3609 let delay = if let Some(custom_delay) = custom_delay {
3610 custom_delay
3611 } else {
3612 BASE_RETRY_DELAY * 2u32.pow((attempt - 1) as u32)
3613 };
3614
3615 // Add a transient message to inform the user
3616 let delay_secs = delay.as_secs();
3617 let retry_message = format!(
3618 "{}. Retrying (attempt {} of {}) in {} seconds...",
3619 error_message, attempt, max_attempts, delay_secs
3620 );
3621
3622 self.insert_retry_message(retry_message, cx);
3623
3624 // Schedule the retry
3625 let thread_handle = cx.entity().downgrade();
3626
3627 cx.spawn(async move |_thread, cx| {
3628 cx.background_executor().timer(delay).await;
3629
3630 thread_handle
3631 .update(cx, |thread, cx| {
3632 // Retry the completion
3633 thread.send_to_model(model, intent, window, cx);
3634 })
3635 .log_err();
3636 })
3637 .detach();
3638
3639 true
3640 } else {
3641 // Max retries exceeded
3642 self.retry_state = None;
3643
3644 let notification_text = if max_attempts == 1 {
3645 "Failed after retrying.".into()
3646 } else {
3647 format!("Failed after retrying {} times.", max_attempts).into()
3648 };
3649
3650 // Stop generating since we're giving up on retrying.
3651 self.pending_completions.clear();
3652
3653 cx.emit(ThreadEvent::RetriesFailed {
3654 message: notification_text,
3655 });
3656
3657 false
3658 }
3659 }
3660
    /// Starts generating a detailed (multi-line) summary of the thread in the
    /// background, unless one is already generated or generating for the
    /// current last message.
    ///
    /// Progress is published through the `detailed_summary_tx`/`rx` watch
    /// pair; on success the thread is saved via `thread_store` so the summary
    /// can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.thread_messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Mark generation as in-flight for the message we're summarizing up to.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails outright, reset the state so a later call
            // can try again.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            // Accumulate the streamed completion; per-chunk errors are logged
            // and skipped rather than aborting the whole summary.
            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
3750
3751 pub async fn wait_for_detailed_summary_or_text(
3752 this: &Entity<Self>,
3753 cx: &mut AsyncApp,
3754 ) -> Option<SharedString> {
3755 let mut detailed_summary_rx = this
3756 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
3757 .ok()?;
3758 loop {
3759 match detailed_summary_rx.recv().await? {
3760 DetailedSummaryState::Generating { .. } => {}
3761 DetailedSummaryState::NotGenerated => {
3762 return this.read_with(cx, |this, _cx| this.text().into()).ok();
3763 }
3764 DetailedSummaryState::Generated { text, .. } => return Some(text),
3765 }
3766 }
3767 }
3768
3769 pub fn latest_detailed_summary_or_text(&self, _cx: &App) -> SharedString {
3770 self.detailed_summary_rx
3771 .borrow()
3772 .text()
3773 .unwrap_or_else(|| self.text().into())
3774 }
3775
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
3782
3783 pub fn use_pending_tools(
3784 &mut self,
3785 window: Option<AnyWindowHandle>,
3786 model: Arc<dyn LanguageModel>,
3787 cx: &mut Context<Self>,
3788 ) -> Vec<PendingToolUse> {
3789 self.auto_capture_telemetry(cx);
3790 let request =
3791 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
3792 let pending_tool_uses = self
3793 .pending_tool_uses_by_id
3794 .values()
3795 .filter(|tool_use| tool_use.status.is_idle())
3796 .cloned()
3797 .collect::<Vec<_>>();
3798
3799 for tool_use in pending_tool_uses.iter() {
3800 self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
3801 }
3802
3803 pending_tool_uses
3804 }
3805
3806 fn use_pending_tool(
3807 &mut self,
3808 tool_use: PendingToolUse,
3809 request: Arc<LanguageModelRequest>,
3810 model: Arc<dyn LanguageModel>,
3811 window: Option<AnyWindowHandle>,
3812 cx: &mut Context<Self>,
3813 ) {
3814 let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
3815 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
3816 };
3817
3818 if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
3819 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
3820 }
3821
3822 if tool.needs_confirmation(&tool_use.input, cx)
3823 && !AgentSettings::get_global(cx).always_allow_tool_actions
3824 {
3825 self.confirm_tool_use(tool_use.id, tool_use.ui_text, tool_use.input, request, tool);
3826 cx.emit(ThreadEvent::ToolConfirmationNeeded);
3827 } else {
3828 self.run_tool(
3829 tool_use.id,
3830 tool_use.ui_text,
3831 tool_use.input,
3832 request,
3833 tool,
3834 model,
3835 window,
3836 cx,
3837 );
3838 }
3839 }
3840
3841 pub fn request_tool_use(
3842 &mut self,
3843 assistant_message_id: MessageId,
3844 tool_use: LanguageModelToolUse,
3845 metadata: ToolUseMetadata,
3846 cx: &App,
3847 ) -> Arc<str> {
3848 let tool_uses = self
3849 .tool_uses_by_assistant_message
3850 .entry(assistant_message_id)
3851 .or_default();
3852
3853 let mut existing_tool_use_found = false;
3854
3855 for existing_tool_use in tool_uses.iter_mut() {
3856 if existing_tool_use.id == tool_use.id {
3857 *existing_tool_use = tool_use.clone();
3858 existing_tool_use_found = true;
3859 }
3860 }
3861
3862 if !existing_tool_use_found {
3863 tool_uses.push(tool_use.clone());
3864 }
3865
3866 let status = if tool_use.is_input_complete {
3867 self.tool_use_metadata_by_id
3868 .insert(tool_use.id.clone(), metadata);
3869
3870 PendingToolUseStatus::Idle
3871 } else {
3872 PendingToolUseStatus::InputStillStreaming
3873 };
3874
3875 let ui_text: Arc<str> = self
3876 .tool_ui_label(
3877 &tool_use.name,
3878 &tool_use.input,
3879 tool_use.is_input_complete,
3880 cx,
3881 )
3882 .into();
3883
3884 let may_perform_edits = self
3885 .tools
3886 .read(cx)
3887 .tool(&tool_use.name, cx)
3888 .is_some_and(|tool| tool.may_perform_edits());
3889
3890 self.pending_tool_uses_by_id.insert(
3891 tool_use.id.clone(),
3892 PendingToolUse {
3893 assistant_message_id,
3894 id: tool_use.id,
3895 name: tool_use.name.clone(),
3896 ui_text: ui_text.clone(),
3897 input: tool_use.input,
3898 may_perform_edits,
3899 status,
3900 },
3901 );
3902
3903 ui_text
3904 }
3905
3906 pub fn tool_ui_label(
3907 &self,
3908 tool_name: &str,
3909 input: &serde_json::Value,
3910 is_input_complete: bool,
3911 cx: &App,
3912 ) -> SharedString {
3913 if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
3914 if is_input_complete {
3915 tool.ui_text(input).into()
3916 } else {
3917 tool.still_streaming_ui_text(input).into()
3918 }
3919 } else {
3920 format!("Unknown tool {tool_name:?}").into()
3921 }
3922 }
3923
3924 fn confirm_tool_use(
3925 &mut self,
3926 tool_use_id: LanguageModelToolUseId,
3927 ui_text: impl Into<Arc<str>>,
3928 input: serde_json::Value,
3929 request: Arc<LanguageModelRequest>,
3930 tool: Arc<dyn Tool>,
3931 ) {
3932 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
3933 let ui_text = ui_text.into();
3934 tool_use.ui_text = ui_text.clone();
3935 let confirmation = Confirmation {
3936 tool_use_id,
3937 input,
3938 request,
3939 tool,
3940 ui_text,
3941 };
3942 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
3943 }
3944 }
3945
    /// Returns each tool use requested by the given assistant message, paired
    /// with its result (`None` while the tool is still running or was never
    /// run). Yields nothing for messages with no tool uses.
    pub fn tool_results(
        &self,
        assistant_message_id: MessageId,
    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
        self.tool_uses_by_assistant_message
            .get(&assistant_message_id)
            .into_iter()
            .flatten()
            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
    }
3956
3957 pub fn handle_hallucinated_tool_use(
3958 &mut self,
3959 tool_use_id: LanguageModelToolUseId,
3960 hallucinated_tool_name: Arc<str>,
3961 window: Option<AnyWindowHandle>,
3962 cx: &mut Context<ZedAgentThread>,
3963 ) {
3964 let available_tools = self.profile.enabled_tools(cx);
3965
3966 let tool_list = available_tools
3967 .iter()
3968 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
3969 .collect::<Vec<_>>()
3970 .join("\n");
3971
3972 let error_message = format!(
3973 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
3974 hallucinated_tool_name, tool_list
3975 );
3976
3977 let pending_tool_use = self.insert_tool_output(
3978 tool_use_id.clone(),
3979 hallucinated_tool_name,
3980 Err(anyhow!("Missing tool call: {error_message}")),
3981 );
3982
3983 cx.emit(ThreadEvent::MissingToolUse {
3984 tool_use_id: tool_use_id.clone(),
3985 ui_text: error_message.into(),
3986 });
3987
3988 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
3989 }
3990
3991 pub fn receive_invalid_tool_json(
3992 &mut self,
3993 tool_use_id: LanguageModelToolUseId,
3994 tool_name: Arc<str>,
3995 invalid_json: Arc<str>,
3996 error: String,
3997 window: Option<AnyWindowHandle>,
3998 cx: &mut Context<ZedAgentThread>,
3999 ) {
4000 log::error!("The model returned invalid input JSON: {invalid_json}");
4001
4002 let pending_tool_use = self.insert_tool_output(
4003 tool_use_id.clone(),
4004 tool_name,
4005 Err(anyhow!("Error parsing input JSON: {error}")),
4006 );
4007 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
4008 pending_tool_use.ui_text.clone()
4009 } else {
4010 log::error!(
4011 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
4012 );
4013 format!("Unknown tool {}", tool_use_id).into()
4014 };
4015
4016 cx.emit(ThreadEvent::InvalidToolInput {
4017 tool_use_id: tool_use_id.clone(),
4018 ui_text,
4019 invalid_input_json: invalid_json,
4020 });
4021
4022 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
4023 }
4024
4025 fn tool_for_name(&self, name: &str, cx: &App) -> Result<Arc<dyn Tool>> {
4026 if let Some(tool) = self.tools.read(cx).tool(name, cx)
4027 && self.profile.is_tool_enabled(tool.source(), tool.name(), cx)
4028 {
4029 Ok(tool)
4030 } else {
4031 let available_tools = self.profile.enabled_tools(cx);
4032 let tool_list = available_tools
4033 .iter()
4034 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
4035 .collect::<Vec<_>>()
4036 .join("\n");
4037 let error_message = format!(
4038 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
4039 name, tool_list
4040 );
4041 Err(anyhow!(error_message))
4042 }
4043 }
4044
4045 pub fn run_tool(
4046 &mut self,
4047 tool_use_id: LanguageModelToolUseId,
4048 ui_text: impl Into<Arc<str>>,
4049 input: serde_json::Value,
4050 request: Arc<LanguageModelRequest>,
4051 tool: Arc<dyn Tool>,
4052 model: Arc<dyn LanguageModel>,
4053 window: Option<AnyWindowHandle>,
4054 cx: &mut Context<ZedAgentThread>,
4055 ) {
4056 let task =
4057 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
4058 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
4059 tool_use.ui_text = ui_text.into();
4060 tool_use.status = PendingToolUseStatus::Running {
4061 _task: task.shared(),
4062 };
4063 }
4064 }
4065
    /// Runs a tool and returns a task that records its output on this thread
    /// when it completes.
    ///
    /// The tool's result card (if any) is stored immediately so the UI can
    /// render it while the tool is still running; the returned task awaits
    /// the tool's output, records it, and calls `tool_finished` (which may
    /// send the accumulated results back to the model).
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<ZedAgentThread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Kick off the tool synchronously; `tool_result.output` is the future
        // that resolves with its eventual output.
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_result_cards.insert(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<ZedAgentThread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; in
                // that case the output is silently discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use =
                            thread.insert_tool_output(tool_use_id.clone(), tool_name, output);
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
4107
    /// Called whenever a tool use completes (successfully, with an error, or
    /// via cancellation).
    ///
    /// Once ALL tools for the current turn are finished, automatically sends
    /// the accumulated tool results back to the configured model — unless the
    /// completion was `canceled`, in which case no follow-up request is made.
    /// Always emits `ThreadEvent::ToolFinished` for the UI.
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
                if !canceled {
                    self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
                }
                self.auto_capture_telemetry(cx);
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
4130
4131 /// Cancels the last pending completion, if there are any pending.
4132 ///
4133 /// Returns whether a completion was canceled.
4134 pub fn cancel_last_completion(
4135 &mut self,
4136 window: Option<AnyWindowHandle>,
4137 cx: &mut Context<Self>,
4138 ) -> bool {
4139 let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
4140
4141 self.retry_state = None;
4142
4143 for pending_tool_use in self.cancel_pending() {
4144 canceled = true;
4145 self.tool_finished(
4146 pending_tool_use.id.clone(),
4147 Some(pending_tool_use),
4148 true,
4149 window,
4150 cx,
4151 );
4152 }
4153
4154 if canceled {
4155 cx.emit(ThreadEvent::CompletionCanceled);
4156
4157 // When canceled, we always want to insert the checkpoint.
4158 // (We skip over finalize_pending_checkpoint, because it
4159 // would conclude we didn't have anything to insert here.)
4160 if let Some(checkpoint) = self.pending_checkpoint.take() {
4161 self.insert_checkpoint(checkpoint, cx);
4162 }
4163 } else {
4164 self.finalize_pending_checkpoint(cx);
4165 }
4166
4167 canceled
4168 }
4169
4170 fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
4171 let mut cancelled_tool_uses = Vec::new();
4172 self.pending_tool_uses_by_id
4173 .retain(|tool_use_id, tool_use| {
4174 if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
4175 return true;
4176 }
4177
4178 let content = "Tool canceled by user".into();
4179 self.tool_results.insert(
4180 tool_use_id.clone(),
4181 LanguageModelToolResult {
4182 tool_use_id: tool_use_id.clone(),
4183 tool_name: tool_use.name.clone(),
4184 content,
4185 output: None,
4186 is_error: true,
4187 },
4188 );
4189 cancelled_tool_uses.push(tool_use.clone());
4190 false
4191 });
4192 cancelled_tool_uses
4193 }
4194
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It does not itself mutate
    /// any editing state; listeners react to the `CancelEditing` event.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
4202
    /// Returns the thread-level feedback rating, if the user has given one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
4206
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
4210
    /// Records the user's rating for a single message and reports it to
    /// telemetry, together with a serialized copy of the thread and a
    /// snapshot of the project. No-op if the same rating is already recorded
    /// for that message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Skip duplicate submissions of the same rating.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything we need on the foreground before handing off to
        // the background executor.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure shouldn't abort the report; fall back to null.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
4267
    /// Records thread-level feedback. When the thread has an assistant
    /// message, the rating is attached to the most recent one (delegating to
    /// `report_message_feedback`); otherwise the rating is stored on the
    /// thread itself and reported to telemetry without a message id.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Prefer attaching the rating to the latest assistant message.
        let last_assistant_message_id = self
            .messages()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure shouldn't abort the report; fall back to null.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
4312
4313 /// Create a snapshot of the current project state including git information and unsaved buffers.
4314 fn project_snapshot(
4315 project: Entity<Project>,
4316 cx: &mut Context<Self>,
4317 ) -> Task<Arc<ProjectSnapshot>> {
4318 let git_store = project.read(cx).git_store().clone();
4319 let worktree_snapshots: Vec<_> = project
4320 .read(cx)
4321 .visible_worktrees(cx)
4322 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
4323 .collect();
4324
4325 cx.spawn(async move |_, cx| {
4326 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
4327
4328 let mut unsaved_buffers = Vec::new();
4329 cx.update(|app_cx| {
4330 let buffer_store = project.read(app_cx).buffer_store();
4331 for buffer_handle in buffer_store.read(app_cx).buffers() {
4332 let buffer = buffer_handle.read(app_cx);
4333 if buffer.is_dirty() {
4334 if let Some(file) = buffer.file() {
4335 let path = file.path().to_string_lossy().to_string();
4336 unsaved_buffers.push(path);
4337 }
4338 }
4339 }
4340 })
4341 .ok();
4342
4343 Arc::new(ProjectSnapshot {
4344 worktree_snapshots,
4345 unsaved_buffer_paths: unsaved_buffers,
4346 timestamp: Utc::now(),
4347 })
4348 })
4349 }
4350
    /// Builds a `WorktreeSnapshot` for a single worktree: its absolute path
    /// plus, when a local git repository covers it, the repo's remote URL,
    /// HEAD sha, current branch, and HEAD-to-worktree diff.
    ///
    /// Any failure along the way (dropped entities, non-local repo, diff
    /// errors) degrades to `git_state: None` rather than erroring.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose working tree contains this worktree's
            // root, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        // Query the repository backend on its job queue; only
                        // local repositories expose a backend to inspect.
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result layers from the entity update
            // and the awaited job, collapsing every failure to `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
4428
4429 pub fn to_markdown(&self, cx: &App) -> Result<String> {
4430 use std::io::Write;
4431 let mut markdown = Vec::new();
4432
4433 let summary = self.summary().or_default();
4434 writeln!(markdown, "# {summary}\n")?;
4435
4436 for message in self.messages() {
4437 writeln!(
4438 markdown,
4439 "## {role}\n",
4440 role = match message.role {
4441 Role::User => "User",
4442 Role::Assistant => "Agent",
4443 Role::System => "System",
4444 }
4445 )?;
4446
4447 if !message.loaded_context.text.is_empty() {
4448 writeln!(markdown, "{}", message.loaded_context.text)?;
4449 }
4450
4451 if !message.loaded_context.images.is_empty() {
4452 writeln!(
4453 markdown,
4454 "\n{} images attached as context.\n",
4455 message.loaded_context.images.len()
4456 )?;
4457 }
4458
4459 for segment in &message.segments {
4460 match segment {
4461 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
4462 MessageSegment::Thinking { text, .. } => {
4463 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
4464 }
4465 MessageSegment::ToolUse { .. } => {}
4466 }
4467 }
4468
4469 for tool_use in self.tool_uses_for_message(message.id, cx) {
4470 writeln!(
4471 markdown,
4472 "**Use Tool: {} ({})**",
4473 tool_use.name, tool_use.id
4474 )?;
4475 writeln!(markdown, "```json")?;
4476 writeln!(
4477 markdown,
4478 "{}",
4479 serde_json::to_string_pretty(&tool_use.input)?
4480 )?;
4481 writeln!(markdown, "```")?;
4482 }
4483
4484 for tool_result in self.tool_results_for_message(message.id) {
4485 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
4486 if tool_result.is_error {
4487 write!(markdown, " (Error)")?;
4488 }
4489
4490 writeln!(markdown, "**\n")?;
4491 match &tool_result.content {
4492 LanguageModelToolResultContent::Text(text) => {
4493 writeln!(markdown, "{text}")?;
4494 }
4495 LanguageModelToolResultContent::Image(image) => {
4496 writeln!(markdown, "", image.source)?;
4497 }
4498 }
4499
4500 if let Some(output) = tool_result.output.as_ref() {
4501 writeln!(
4502 markdown,
4503 "\n\nDebug Output:\n\n```json\n{}\n```\n",
4504 serde_json::to_string_pretty(output)?
4505 )?;
4506 }
4507 }
4508 }
4509
4510 Ok(String::from_utf8_lossy(&markdown).to_string())
4511 }
4512
4513 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
4514 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
4515 return;
4516 }
4517
4518 let now = Instant::now();
4519 if let Some(last) = self.last_auto_capture_at {
4520 if now.duration_since(last).as_secs() < 10 {
4521 return;
4522 }
4523 }
4524
4525 self.last_auto_capture_at = Some(now);
4526
4527 let thread_id = self.id().clone();
4528 let github_login = self
4529 .project
4530 .read(cx)
4531 .user_store()
4532 .read(cx)
4533 .current_user()
4534 .map(|user| user.github_login.clone());
4535 let client = self.project.read(cx).client();
4536 let serialize_task = self.serialize(cx);
4537
4538 cx.background_executor()
4539 .spawn(async move {
4540 if let Ok(serialized_thread) = serialize_task.await {
4541 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
4542 telemetry::event!(
4543 "Agent Thread Auto-Captured",
4544 thread_id = thread_id.to_string(),
4545 thread_data = thread_data,
4546 auto_capture_reason = "tracked_user",
4547 github_login = github_login
4548 );
4549
4550 client.telemetry().flush_events().await;
4551 }
4552 }
4553 })
4554 .detach();
4555 }
4556
    /// Returns the token usage accumulated across every completion request
    /// made over the lifetime of this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
4560
4561 pub fn token_usage_up_to_message(&self, message_id: MessageId, _cx: &App) -> TotalTokenUsage {
4562 let Some(model) = self.configured_model.as_ref() else {
4563 return TotalTokenUsage::default();
4564 };
4565
4566 let max = model.model.max_token_count();
4567
4568 let index = self
4569 .messages()
4570 .position(|msg| msg.id == message_id)
4571 .unwrap_or(0);
4572
4573 if index == 0 {
4574 return TotalTokenUsage { total: 0, max };
4575 }
4576
4577 let token_usage = &self
4578 .request_token_usage
4579 .get(index - 1)
4580 .cloned()
4581 .unwrap_or_default();
4582
4583 TotalTokenUsage {
4584 total: token_usage.total_tokens(),
4585 max,
4586 }
4587 }
4588
4589 pub fn total_token_usage(&self, cx: &App) -> Option<TotalTokenUsage> {
4590 let model = self.configured_model.as_ref()?;
4591
4592 let max = model.model.max_token_count();
4593
4594 if let Some(exceeded_error) = &self.exceeded_window_error {
4595 if model.model.id() == exceeded_error.model_id {
4596 return Some(TotalTokenUsage {
4597 total: exceeded_error.token_count,
4598 max,
4599 });
4600 }
4601 }
4602
4603 let total = self
4604 .token_usage_at_last_message(cx)
4605 .unwrap_or_default()
4606 .total_tokens();
4607
4608 Some(TotalTokenUsage { total, max })
4609 }
4610
4611 fn token_usage_at_last_message(&self, _cx: &App) -> Option<TokenUsage> {
4612 self.request_token_usage
4613 .get(self.messages().len().saturating_sub(1))
4614 .or_else(|| self.request_token_usage.last())
4615 .cloned()
4616 }
4617
4618 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage, cx: &App) {
4619 let placeholder = self.token_usage_at_last_message(cx).unwrap_or_default();
4620 let len = self.messages().len();
4621 self.request_token_usage.resize(len, placeholder);
4622
4623 if let Some(last) = self.request_token_usage.last_mut() {
4624 *last = token_usage;
4625 }
4626 }
4627
4628 fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
4629 self.project.update(cx, |project, cx| {
4630 project.user_store().update(cx, |user_store, cx| {
4631 user_store.update_model_request_usage(
4632 ModelRequestUsage(RequestUsage {
4633 amount: amount as i32,
4634 limit,
4635 }),
4636 cx,
4637 )
4638 })
4639 });
4640 }
4641
4642 pub fn deny_tool_use(
4643 &mut self,
4644 tool_use_id: LanguageModelToolUseId,
4645 tool_name: Arc<str>,
4646 window: Option<AnyWindowHandle>,
4647 cx: &mut Context<Self>,
4648 ) {
4649 let err = Err(anyhow::anyhow!(
4650 "Permission to run tool action denied by user"
4651 ));
4652
4653 self.insert_tool_output(tool_use_id.clone(), tool_name, err);
4654 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
4655 }
4656
4657 pub fn insert_tool_output(
4658 &mut self,
4659 tool_use_id: LanguageModelToolUseId,
4660 tool_name: Arc<str>,
4661 output: Result<ToolResultOutput>,
4662 ) -> Option<PendingToolUse> {
4663 let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
4664
4665 telemetry::event!(
4666 "Agent Tool Finished",
4667 model = metadata
4668 .as_ref()
4669 .map(|metadata| metadata.model.telemetry_id()),
4670 model_provider = metadata
4671 .as_ref()
4672 .map(|metadata| metadata.model.provider_id().to_string()),
4673 thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
4674 prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
4675 tool_name,
4676 success = output.is_ok()
4677 );
4678
4679 match output {
4680 Ok(output) => {
4681 let tool_result = output.content;
4682 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
4683
4684 let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
4685
4686 // Protect from overly large output
4687 let tool_output_limit = self
4688 .configured_model
4689 .as_ref()
4690 .map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
4691 .unwrap_or(usize::MAX);
4692
4693 let content = match tool_result {
4694 ToolResultContent::Text(text) => {
4695 let text = if text.len() < tool_output_limit {
4696 text
4697 } else {
4698 let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
4699 format!(
4700 "Tool result too long. The first {} bytes:\n\n{}",
4701 truncated.len(),
4702 truncated
4703 )
4704 };
4705 LanguageModelToolResultContent::Text(text.into())
4706 }
4707 ToolResultContent::Image(language_model_image) => {
4708 if language_model_image.estimate_tokens() < tool_output_limit {
4709 LanguageModelToolResultContent::Image(language_model_image)
4710 } else {
4711 self.tool_results.insert(
4712 tool_use_id.clone(),
4713 LanguageModelToolResult {
4714 tool_use_id: tool_use_id.clone(),
4715 tool_name,
4716 content: "Tool responded with an image that would exceeded the remaining tokens".into(),
4717 is_error: true,
4718 output: None,
4719 },
4720 );
4721
4722 return old_use;
4723 }
4724 }
4725 };
4726
4727 self.tool_results.insert(
4728 tool_use_id.clone(),
4729 LanguageModelToolResult {
4730 tool_use_id: tool_use_id.clone(),
4731 tool_name,
4732 content,
4733 is_error: false,
4734 output: output.output,
4735 },
4736 );
4737
4738 old_use
4739 }
4740 Err(err) => {
4741 self.tool_results.insert(
4742 tool_use_id.clone(),
4743 LanguageModelToolResult {
4744 tool_use_id: tool_use_id.clone(),
4745 tool_name,
4746 content: LanguageModelToolResultContent::Text(err.to_string().into()),
4747 is_error: true,
4748 output: None,
4749 },
4750 );
4751
4752 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
4753 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
4754 }
4755
4756 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
4757 }
4758 }
4759 }
4760}
4761
/// User-visible errors surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's plan requires payment before more requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The user hit the model request limit of their current plan.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, displayed with a short header and a longer message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
4774
/// Notifications emitted by a thread while it streams completions, runs
/// tools, and mutates its message list.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be presented to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// Streamed assistant text appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    // NOTE(review): name suggests a newer, segment-addressed variant of
    // `StreamedToolUse`; confirm which one consumers rely on.
    StreamedToolUse2 {
        message_id: MessageId,
        segment_index: usize,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model emitted tool input that could not be used.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        /// The raw JSON the model emitted, kept for display/debugging.
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
    RetriesFailed {
        message: SharedString,
    },
}
4826
// Lets observers subscribe to `ThreadEvent`s emitted by the thread entity.
impl EventEmitter<ThreadEvent> for ZedAgentThread {}
4828
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier distinguishing this completion from others.
    id: usize,
    queue_state: QueueState,
    // Held so the streaming task stays owned for the lifetime of this entry.
    _task: Task<()>,
}
4834
/// Helper for extracting tool use related state from serialized messages
#[derive(Default)]
struct DeserializedToolUse {
    // NOTE(review): never populated by `DeserializedToolUse::new`; left at
    // its default. Confirm whether metadata is meant to survive reload.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
    // Tool uses grouped under the assistant message that requested them.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    // Renderable cards rebuilt from serialized tool output (UI only).
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    // Results keyed by the tool use that produced them.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
}
4843
4844impl DeserializedToolUse {
4845 fn new(
4846 messages: &[SerializedMessage],
4847 project: &Entity<Project>,
4848 tools: &Entity<ToolWorkingSet>,
4849 window: Option<&mut Window>, // None in headless mode
4850 cx: &mut App,
4851 ) -> Self {
4852 let mut this = Self::default();
4853
4854 let mut window = window;
4855 let mut tool_names_by_id = HashMap::default();
4856
4857 for message in messages {
4858 match message.role {
4859 Role::Assistant => {
4860 if !message.tool_uses.is_empty() {
4861 let tool_uses = message
4862 .tool_uses
4863 .iter()
4864 .map(|tool_use| LanguageModelToolUse {
4865 id: tool_use.id.clone(),
4866 name: tool_use.name.clone().into(),
4867 raw_input: tool_use.input.to_string(),
4868 input: tool_use.input.clone(),
4869 is_input_complete: true,
4870 })
4871 .collect::<Vec<_>>();
4872
4873 tool_names_by_id.extend(
4874 tool_uses
4875 .iter()
4876 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
4877 );
4878
4879 this.tool_uses_by_assistant_message
4880 .insert(message.id, tool_uses);
4881
4882 for tool_result in &message.tool_results {
4883 let tool_use_id = tool_result.tool_use_id.clone();
4884 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
4885 log::warn!("no tool name found for tool use: {tool_use_id:?}");
4886 continue;
4887 };
4888
4889 this.tool_results.insert(
4890 tool_use_id.clone(),
4891 LanguageModelToolResult {
4892 tool_use_id: tool_use_id.clone(),
4893 tool_name: tool_use.clone(),
4894 is_error: tool_result.is_error,
4895 content: tool_result.content.clone(),
4896 output: tool_result.output.clone(),
4897 },
4898 );
4899
4900 if let Some(window) = &mut window {
4901 if let Some(tool) = tools.read(cx).tool(tool_use, cx) {
4902 if let Some(output) = tool_result.output.clone() {
4903 if let Some(card) = tool.deserialize_card(
4904 output,
4905 project.clone(),
4906 window,
4907 cx,
4908 ) {
4909 this.tool_result_cards.insert(tool_use_id, card);
4910 }
4911 }
4912 }
4913 }
4914 }
4915 }
4916 }
4917 Role::System | Role::User => {}
4918 }
4919 }
4920
4921 this
4922 }
4923}
4924
4925/// Resolves tool name conflicts by ensuring all tool names are unique.
4926///
4927/// When multiple tools have the same name, this function applies the following rules:
4928/// 1. Native tools always keep their original name
4929/// 2. Context server tools get prefixed with their server ID and an underscore
4930/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
4931/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
4932///
4933/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
4934fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
4935 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
4936 let mut tool_name = tool.name();
4937 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
4938 tool_name
4939 }
4940
4941 const MAX_TOOL_NAME_LENGTH: usize = 64;
4942
4943 let mut duplicated_tool_names = HashSet::default();
4944 let mut seen_tool_names = HashSet::default();
4945 for tool in tools {
4946 let tool_name = resolve_tool_name(tool);
4947 if seen_tool_names.contains(&tool_name) {
4948 debug_assert!(
4949 tool.source() != assistant_tool::ToolSource::Native,
4950 "There are two built-in tools with the same name: {}",
4951 tool_name
4952 );
4953 duplicated_tool_names.insert(tool_name);
4954 } else {
4955 seen_tool_names.insert(tool_name);
4956 }
4957 }
4958
4959 if duplicated_tool_names.is_empty() {
4960 return tools
4961 .into_iter()
4962 .map(|tool| (resolve_tool_name(tool), tool.clone()))
4963 .collect();
4964 }
4965
4966 tools
4967 .into_iter()
4968 .filter_map(|tool| {
4969 let mut tool_name = resolve_tool_name(tool);
4970 if !duplicated_tool_names.contains(&tool_name) {
4971 return Some((tool_name, tool.clone()));
4972 }
4973 match tool.source() {
4974 assistant_tool::ToolSource::Native => {
4975 // Built-in tools always keep their original name
4976 Some((tool_name, tool.clone()))
4977 }
4978 assistant_tool::ToolSource::ContextServer { id } => {
4979 // Context server tools are prefixed with the context server ID, and truncated if necessary
4980 tool_name.insert(0, '_');
4981 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
4982 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
4983 let mut id = id.to_string();
4984 id.truncate(len);
4985 tool_name.insert_str(0, &id);
4986 } else {
4987 tool_name.insert_str(0, &id);
4988 }
4989
4990 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
4991
4992 if seen_tool_names.contains(&tool_name) {
4993 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
4994 None
4995 } else {
4996 Some((tool_name, tool.clone()))
4997 }
4998 }
4999 }
5000 })
5001 .collect()
5002}
5003
5004#[cfg(test)]
5005mod tests {
5006 use super::*;
5007 use crate::{
5008 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
5009 };
5010
5011 // Test-specific constants
5012 const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
5013 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
5014 use assistant_tool::{ToolRegistry, ToolSource};
5015
5016 use futures::future::BoxFuture;
5017 use futures::stream::BoxStream;
5018 use gpui::TestAppContext;
5019 use icons::IconName;
5020 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
5021 use language_model::{
5022 LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
5023 LanguageModelProviderName, LanguageModelToolChoice,
5024 };
5025 use parking_lot::Mutex;
5026 use project::{FakeFs, Project};
5027 use prompt_store::PromptBuilder;
5028 use serde_json::json;
5029 use settings::{Settings, SettingsStore};
5030 use std::panic;
5031 use std::sync::Arc;
5032 use std::time::Duration;
5033 use theme::ThemeSettings;
5034 use util::path;
5035 use workspace::Workspace;
5036
    /// Basic round trip: sending a user message issues a completion request
    /// containing the system prompt plus the user text, and the streamed
    /// response lands as an assistant message once the stream ends.
    #[gpui::test]
    async fn test_send_to_model_basic(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_workspace, _thread_store, agent, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        agent.update(cx, |agent, cx| {
            agent.send_message("Hello", model.clone(), None, cx);
        });

        // The request should now be pending on the fake model.
        let fake_model = model.as_fake();
        cx.run_until_parked();
        let pending_completions = fake_model.pending_completions();

        let request = pending_completions.last().unwrap();
        assert_eq!(request.intent, Some(CompletionIntent::UserPrompt));
        assert_eq!(request.messages[0].role, Role::System);
        assert_eq!(request.messages[1].role, Role::User);
        assert_eq!(
            request.messages[1].content[0],
            MessageContent::Text("Hello".into()),
        );
        assert_eq!(agent.read_with(cx, |agent, _| agent.is_generating()), true);

        simulate_successful_response(&fake_model, cx);
        cx.run_until_parked();

        // Generation stops once the stream completes.
        assert_eq!(agent.read_with(cx, |agent, _| agent.is_generating()), false);

        agent.read_with(cx, |thread, _cx| {
            assert_eq!(thread.thread_messages[0].role, Role::User);
            assert_eq!(
                &thread.thread_messages[0].segments[0],
                &MessageSegment::Text("Hello".to_string())
            );
            assert_eq!(thread.thread_messages[1].role, Role::Assistant);
            assert_eq!(
                &thread.thread_messages[1].segments[0],
                &MessageSegment::Text("Assistant response".to_string())
            )
        });
    }
5082
    /// Tool round trip: the model requests `read_file`, the tool result is
    /// fed back in a follow-up `ToolResults` request, and generation keeps
    /// going until the model finishes without further tool calls.
    #[gpui::test]
    async fn test_send_to_model_with_tools(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_workspace, thread_store, agent, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Register a fake native tool that returns canned file contents.
        thread_store.update(cx, |thread_store, cx| {
            thread_store.tools().update(cx, |tools, _| {
                tools.insert(Arc::new(TestTool::new(
                    "read_file",
                    ToolSource::Native,
                    Ok("the lazy dog...".to_string()),
                )));
            });
        });

        agent.update(cx, |agent, cx| {
            agent.send_message("Read foo.txt", model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        cx.run_until_parked();

        assert_eq!(agent.read_with(cx, |agent, _| agent.is_generating()), true);

        let pending_completions = fake_model.pending_completions();
        let request = pending_completions.last().unwrap();
        assert_eq!(request.intent, Some(CompletionIntent::UserPrompt));
        assert_eq!(request.messages[0].role, Role::System);
        assert_eq!(request.messages[1].role, Role::User);
        assert_eq!(
            request.messages[1].content[0],
            MessageContent::Text("Read foo.txt".into()),
        );

        // Stream a partial answer followed by a tool-use request.
        fake_model.stream_last_completion_response("I'll do so");
        fake_model.stream_last_completion_response(LanguageModelToolUse {
            id: "id".into(),
            name: "read_file".into(),
            raw_input: "foo.txt".into(),
            input: "foo.txt".into(),
            is_input_complete: true,
        });
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        assert_eq!(agent.read_with(cx, |agent, _| agent.is_generating()), true);

        // The tool ran and its result was sent back to the model.
        let pending_completions = fake_model.pending_completions();
        let request = pending_completions.last().unwrap();
        assert_eq!(request.intent, Some(CompletionIntent::ToolResults));
        assert_eq!(request.messages.len(), 4);

        let tool_result = &request.messages[3].content[0].as_tool_result().unwrap();
        assert_eq!(tool_result.tool_name.as_ref(), "read_file");
        assert_eq!(
            tool_result.content,
            LanguageModelToolResultContent::Text("the lazy dog...".into())
        );

        agent.read_with(cx, |thread, _cx| {
            assert_eq!(thread.thread_messages[0].role, Role::User);
            assert_eq!(
                &thread.thread_messages[0].segments[0],
                &MessageSegment::Text("Read foo.txt".to_string())
            );
            assert_eq!(thread.thread_messages[1].role, Role::Assistant);
            assert_eq!(
                &thread.thread_messages[1].segments[0],
                &MessageSegment::Text("I'll do so".to_string())
            );

            // The tool use and its output are recorded as a message segment.
            let MessageSegment::ToolUse(ToolUseSegment { name, output, .. }) =
                &thread.thread_messages[1].segments[1]
            else {
                panic!("Expected ToolUse segment")
            };
            assert_eq!(name.as_ref(), "read_file");
            assert_eq!(
                output.as_ref().unwrap().as_ref().unwrap(),
                &LanguageModelToolResultContent::Text("the lazy dog...".into())
            );
        });

        assert_eq!(agent.read_with(cx, |agent, _| agent.is_generating()), true);

        fake_model.stream_last_completion_response("Great!");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
        assert_eq!(agent.read_with(cx, |agent, _| agent.is_generating()), false);
    }
5177
    /// A user message carrying file context: the rendered `<context>` block
    /// is stored on the message and prepended to the text in the request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, agent, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = agent.update(cx, |agent, cx| {
            agent.send_message(
                UserMessageParams {
                    text: "Please explain this code".to_string(),
                    creases: Vec::new(),
                    checkpoint: None,
                    context: loaded_context,
                },
                model.clone(),
                None,
                cx,
            )
        });

        // Check content and context in message object
        let message = agent.read_with(cx, |thread, _cx| {
            thread.message(message_id).unwrap().clone()
        });

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert!(matches!(
            &message.segments[0],
            MessageSegment::Text(txt) if txt == "Please explain this code",
        ));
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = agent.update(cx, |agent, cx| {
            agent.build_request(&model, CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
5260
    /// Context deduplication: each message only carries context items that
    /// earlier messages have not already included, and
    /// `new_context_for_thread` respects an "up to message" cutoff plus
    /// removals from the context store.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, agent, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(agent.read(cx), None, cx)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = agent.update(cx, |agent, cx| {
            agent.send_message(
                UserMessageParams {
                    text: "Message 1".to_string(),
                    creases: Vec::new(),
                    checkpoint: None,
                    context: loaded_context,
                },
                model.clone(),
                None,
                cx,
            )
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(agent.read(cx), None, cx)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = agent.update(cx, |agent, cx| {
            agent.send_message(
                UserMessageParams {
                    text: "Message 2".to_string(),
                    creases: Vec::new(),
                    checkpoint: None,
                    context: loaded_context,
                },
                model.clone(),
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(agent.read(cx), None, cx)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = agent.update(cx, |agent, cx| {
            agent.send_message(
                UserMessageParams {
                    text: "Message 3".to_string(),
                    creases: Vec::new(),
                    checkpoint: None,
                    context: loaded_context,
                },
                model.clone(),
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = agent.read_with(cx, |thread, _cx| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = agent.update(cx, |agent, cx| {
            agent.build_request(&model, CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, everything attached after message 1
        // counts as "new" again (file2, file3) plus the newly added file4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(agent.read(cx), Some(message2_id), cx)
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(agent.read(cx), Some(message2_id), cx)
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(agent.read(cx), Some(message2_id), cx)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
5446
    /// Messages without any attached context have an empty context string and
    /// appear verbatim in the built request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, agent, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = agent.update(cx, |agent, cx| {
            agent.send_message(
                "What is the best way to learn Rust?",
                model.clone(),
                None,
                cx,
            )
        });

        // Check content and context in message object
        let message = agent.read_with(cx, |thread, _cx| {
            thread.message(message_id).unwrap().clone()
        });

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert!(matches!(
            &message.segments[0],
            MessageSegment::Text(txt) if txt == "What is the best way to learn Rust?",
        ));
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = agent.update(cx, |agent, cx| {
            agent.build_request(&model, CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = agent.update(cx, |agent, cx| {
            agent.send_message("Are there any good books?", model.clone(), None, cx)
        });

        let message2 = agent.read_with(cx, |thread, _cx| {
            thread.message(message2_id).unwrap().clone()
        });
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = agent.update(cx, |agent, cx| {
            agent.build_request(&model, CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
5520
    /// A freshly created thread starts with the default agent profile.
    #[gpui::test]
    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, agent, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Check that we are starting with the default profile
        let profile = cx.read(|cx| agent.read(cx).profile.clone());
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
        assert_eq!(
            profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
5542
    /// The agent profile survives a serialize/deserialize round trip.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, agent, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = agent
            .update(cx, |agent, cx| agent.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form (headless: no window).
        let deserialized = cx.update(|cx| {
            agent.update(cx, |agent, cx| {
                ZedAgentThread::deserialize(
                    agent.id().clone(),
                    serialized,
                    agent.project.clone(),
                    agent.tools.clone(),
                    agent.prompt_builder.clone(),
                    agent.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
5585
5586 #[gpui::test]
5587 async fn test_temperature_setting(cx: &mut TestAppContext) {
5588 init_test_settings(cx);
5589
5590 let project = create_test_project(
5591 cx,
5592 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
5593 )
5594 .await;
5595
5596 let (_workspace, _thread_store, agent, _context_store, model) =
5597 setup_test_environment(cx, project.clone()).await;
5598
5599 // Both model and provider
5600 cx.update(|cx| {
5601 AgentSettings::override_global(
5602 AgentSettings {
5603 model_parameters: vec![LanguageModelParameters {
5604 provider: Some(model.provider_id().0.to_string().into()),
5605 model: Some(model.id().0.clone()),
5606 temperature: Some(0.66),
5607 }],
5608 ..AgentSettings::get_global(cx).clone()
5609 },
5610 cx,
5611 );
5612 });
5613
5614 let request = agent.update(cx, |agent, cx| {
5615 agent.build_request(&model, CompletionIntent::UserPrompt, cx)
5616 });
5617 assert_eq!(request.temperature, Some(0.66));
5618
5619 // Only model
5620 cx.update(|cx| {
5621 AgentSettings::override_global(
5622 AgentSettings {
5623 model_parameters: vec![LanguageModelParameters {
5624 provider: None,
5625 model: Some(model.id().0.clone()),
5626 temperature: Some(0.66),
5627 }],
5628 ..AgentSettings::get_global(cx).clone()
5629 },
5630 cx,
5631 );
5632 });
5633
5634 let request = agent.update(cx, |agent, cx| {
5635 agent.build_request(&model, CompletionIntent::UserPrompt, cx)
5636 });
5637 assert_eq!(request.temperature, Some(0.66));
5638
5639 // Only provider
5640 cx.update(|cx| {
5641 AgentSettings::override_global(
5642 AgentSettings {
5643 model_parameters: vec![LanguageModelParameters {
5644 provider: Some(model.provider_id().0.to_string().into()),
5645 model: None,
5646 temperature: Some(0.66),
5647 }],
5648 ..AgentSettings::get_global(cx).clone()
5649 },
5650 cx,
5651 );
5652 });
5653
5654 let request = agent.update(cx, |agent, cx| {
5655 agent.build_request(&model, CompletionIntent::UserPrompt, cx)
5656 });
5657 assert_eq!(request.temperature, Some(0.66));
5658
5659 // Same model name, different provider
5660 cx.update(|cx| {
5661 AgentSettings::override_global(
5662 AgentSettings {
5663 model_parameters: vec![LanguageModelParameters {
5664 provider: Some("anthropic".into()),
5665 model: Some(model.id().0.clone()),
5666 temperature: Some(0.66),
5667 }],
5668 ..AgentSettings::get_global(cx).clone()
5669 },
5670 cx,
5671 );
5672 });
5673
5674 let request = agent.update(cx, |agent, cx| {
5675 agent.build_request(&model, CompletionIntent::UserPrompt, cx)
5676 });
5677 assert_eq!(request.temperature, None);
5678 }
5679
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, agent, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        agent.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        agent.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        agent.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        agent.update(cx, |agent, cx| {
            agent.send_message("Hi", model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        agent.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        agent.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        agent.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream; the
        // chunks are expected to be concatenated into one summary string.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        agent.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary (manual override is only
        // honored once the summary reached the Ready state)
        agent.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        agent.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        agent.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        agent.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
5758
5759 #[gpui::test]
5760 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
5761 init_test_settings(cx);
5762
5763 let project = create_test_project(cx, json!({})).await;
5764
5765 let (_, _thread_store, agent, _context_store, model) =
5766 setup_test_environment(cx, project.clone()).await;
5767
5768 test_summarize_error(&model, &agent, cx);
5769
5770 // Now we should be able to set a summary
5771 agent.update(cx, |thread, cx| {
5772 thread.set_summary("Brief Intro", cx);
5773 });
5774
5775 agent.read_with(cx, |thread, _| {
5776 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
5777 assert_eq!(thread.summary().or_default(), "Brief Intro");
5778 });
5779 }
5780
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, agent, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the summary-error state.
        test_summarize_error(&model, &agent, cx);

        // Sending another message should not trigger another summarize request
        agent.update(cx, |agent, cx| {
            agent.send_message("How are you?", model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        agent.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        agent.update(cx, |agent, cx| {
            agent.summarize(cx);
        });

        agent.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Complete the manually-requested summary stream successfully.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        agent.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
5824
    #[gpui::test]
    fn test_resolve_tool_name_conflicts() {
        // No conflicts: names pass through unchanged.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native, Ok("")),
                TestTool::new("tool2", ToolSource::Native, Ok("")),
                TestTool::new(
                    "tool3",
                    ToolSource::ContextServer { id: "mcp-1".into() },
                    Ok(""),
                ),
            ],
            vec!["tool1", "tool2", "tool3"],
        );

        // Two context servers expose the same tool name: both get prefixed
        // with their server id.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native, Ok("")),
                TestTool::new("tool2", ToolSource::Native, Ok("")),
                TestTool::new(
                    "tool3",
                    ToolSource::ContextServer { id: "mcp-1".into() },
                    Ok(""),
                ),
                TestTool::new(
                    "tool3",
                    ToolSource::ContextServer { id: "mcp-2".into() },
                    Ok(""),
                ),
            ],
            vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // Native tool keeps its plain name even when context-server tools
        // collide with it; only the server-provided ones get prefixed.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native, Ok("")),
                TestTool::new("tool2", ToolSource::Native, Ok("")),
                TestTool::new("tool3", ToolSource::Native, Ok("")),
                TestTool::new(
                    "tool3",
                    ToolSource::ContextServer { id: "mcp-1".into() },
                    Ok(""),
                ),
                TestTool::new(
                    "tool3",
                    ToolSource::ContextServer { id: "mcp-2".into() },
                    Ok(""),
                ),
            ],
            vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // Test that tool with very long name is always truncated (to 64 chars,
        // per the expected value below)
        assert_resolve_tool_name_conflicts(
            vec![TestTool::new(
                "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
                ToolSource::Native,
                Ok(""),
            )],
            vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
        );

        // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new(
                    "tool-with-very-very-very-long-name",
                    ToolSource::Native,
                    Ok(""),
                ),
                TestTool::new(
                    "tool-with-very-very-very-long-name",
                    ToolSource::ContextServer {
                        id: "mcp-with-very-very-very-long-name".into(),
                    },
                    Ok(""),
                ),
            ],
            vec![
                "tool-with-very-very-very-long-name",
                "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
            ],
        );

        // Runs `resolve_tool_name_conflicts` on `tools` and asserts the
        // resulting names match `expected`, in order.
        fn assert_resolve_tool_name_conflicts(
            tools: Vec<TestTool>,
            expected: Vec<impl Into<String>>,
        ) {
            let tools: Vec<Arc<dyn Tool>> = tools
                .into_iter()
                .map(|t| Arc::new(t) as Arc<dyn Tool>)
                .collect();
            let tools = resolve_tool_name_conflicts(&tools);
            assert_eq!(tools.len(), expected.len());
            for (i, expected_name) in expected.into_iter().enumerate() {
                let expected_name = expected_name.into();
                let actual_name = &tools[i].0;
                assert_eq!(
                    actual_name, &expected_name,
                    "Expected '{}' got '{}' at index {}",
                    expected_name, actual_name, i
                );
            }
        }
    }
5930
    // Minimal `Tool` implementation used by the tests in this module.
    struct TestTool {
        // Tool name reported through `Tool::name`.
        name: String,
        // Where the tool claims to come from (native vs. a context server).
        source: ToolSource,
        // Canned outcome returned by `run`: Ok -> text content, Err -> error.
        result: Result<String, String>,
    }
5936
5937 impl TestTool {
5938 fn new(
5939 name: impl Into<String>,
5940 source: ToolSource,
5941 result: Result<impl Into<String>, String>,
5942 ) -> Self {
5943 Self {
5944 name: name.into(),
5945 result: result.map(|r| r.into()),
5946 source,
5947 }
5948 }
5949 }
5950
5951 impl Tool for TestTool {
5952 fn name(&self) -> String {
5953 self.name.clone()
5954 }
5955
5956 fn icon(&self) -> IconName {
5957 IconName::Ai
5958 }
5959
5960 fn may_perform_edits(&self) -> bool {
5961 false
5962 }
5963
5964 fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
5965 true
5966 }
5967
5968 fn source(&self) -> ToolSource {
5969 self.source.clone()
5970 }
5971
5972 fn description(&self) -> String {
5973 "Test tool".to_string()
5974 }
5975
5976 fn ui_text(&self, _input: &serde_json::Value) -> String {
5977 "Test tool".to_string()
5978 }
5979
5980 fn run(
5981 self: Arc<Self>,
5982 _input: serde_json::Value,
5983 _request: Arc<LanguageModelRequest>,
5984 _project: Entity<Project>,
5985 _action_log: Entity<ActionLog>,
5986 _model: Arc<dyn LanguageModel>,
5987 _window: Option<AnyWindowHandle>,
5988 _cx: &mut App,
5989 ) -> assistant_tool::ToolResult {
5990 assistant_tool::ToolResult {
5991 output: Task::ready(match self.result.clone() {
5992 Ok(content) => Ok(ToolResultOutput {
5993 content: ToolResultContent::Text(content),
5994 output: None,
5995 }),
5996 Err(e) => Err(anyhow!(e)),
5997 }),
5998 card: None,
5999 }
6000 }
6001 }
6002
    // Error kinds the retry-test models below inject into `stream_completion`
    // to simulate provider failures.
    enum TestError {
        Overloaded,
        InternalServerError,
    }
6008
    #[gpui::test]
    async fn test_retry_single_attempt(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_workspace, _thread_store, agent, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Create a model that fails once then succeeds
        let attempt_count = Arc::new(Mutex::new(0));
        let attempt_count_clone = attempt_count.clone();

        let retry_model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            attempt_count: attempt_count_clone,
            fail_attempts: 1,
            error_type: TestError::Overloaded,
        });

        agent.update(cx, |agent, cx| {
            agent.send_message("Hello", retry_model.clone(), None, cx);
        });

        // First attempt should fail
        cx.run_until_parked();
        assert_eq!(*attempt_count.lock(), 1);

        // Advance clock for retry delay
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();

        // Second attempt should succeed
        assert_eq!(*attempt_count.lock(), 2);

        // Simulate successful response streamed through the inner fake model
        let fake_model = retry_model.as_fake();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Verify the message was sent successfully
        agent.read_with(cx, |thread, _cx| {
            assert_eq!(thread.thread_messages.len(), 2);
            assert_eq!(thread.thread_messages[0].role, Role::User);
            assert_eq!(thread.thread_messages[1].role, Role::Assistant);
            assert_eq!(
                &thread.thread_messages[1].segments[0],
                &MessageSegment::Text("Assistant response".to_string())
            );
        });
    }
6060
    #[gpui::test]
    async fn test_retry_max_attempts_exceeded(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_workspace, _thread_store, agent, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Create a model that always fails (one more failure than the thread
        // will ever attempt)
        let attempt_count = Arc::new(Mutex::new(0));
        let attempt_count_clone = attempt_count.clone();

        let retry_model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            attempt_count: attempt_count_clone,
            fail_attempts: (MAX_RETRY_ATTEMPTS + 1) as usize,
            error_type: TestError::InternalServerError,
        });

        agent.update(cx, |agent, cx| {
            agent.send_message("Hello", retry_model.clone(), None, cx);
        });

        // Run through all retry attempts, advancing the fake clock by the
        // expected exponential-backoff delay between attempts
        for attempt in 1..=MAX_RETRY_ATTEMPTS {
            cx.run_until_parked();
            assert_eq!(*attempt_count.lock(), attempt as usize);

            if attempt < MAX_RETRY_ATTEMPTS {
                // Advance clock for exponential backoff
                let delay = BASE_RETRY_DELAY * 2_u32.pow((attempt - 1) as u32);
                cx.executor().advance_clock(delay);
                cx.run_until_parked();
            }
        }

        cx.run_until_parked();

        // Should not retry beyond MAX_RETRY_ATTEMPTS
        assert_eq!(*attempt_count.lock(), MAX_RETRY_ATTEMPTS as usize);

        // Verify no messages were added (failure case)
        agent.read_with(cx, |agent, _cx| {
            assert_eq!(agent.thread_messages.len(), 1); // Only user message
            assert_eq!(agent.thread_messages[0].role, Role::User);
        });
    }
6108
    #[gpui::test]
    async fn test_retry_exponential_backoff(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_workspace, _thread_store, agent, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Create a model that fails multiple times; the fourth attempt will
        // be handled by the inner fake model
        let attempt_count = Arc::new(Mutex::new(0));
        let attempt_count_clone = attempt_count.clone();

        let retry_model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            attempt_count: attempt_count_clone,
            fail_attempts: 3,
            error_type: TestError::Overloaded,
        });

        agent.update(cx, |agent, cx| {
            agent.send_message("Hello", retry_model.clone(), None, cx);
        });

        // First attempt
        cx.run_until_parked();
        assert_eq!(*attempt_count.lock(), 1);

        // Second attempt after 5 seconds (BASE_RETRY_DELAY * 2^0)
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();
        assert_eq!(*attempt_count.lock(), 2);

        // Third attempt after 10 seconds (5 * 2^1)
        cx.executor().advance_clock(BASE_RETRY_DELAY * 2);
        cx.run_until_parked();
        assert_eq!(*attempt_count.lock(), 3);

        // Fourth attempt after 20 seconds (5 * 2^2)
        cx.executor().advance_clock(BASE_RETRY_DELAY * 4);
        cx.run_until_parked();
        assert_eq!(*attempt_count.lock(), 4);

        // Simulate successful response
        let fake_model = retry_model.as_fake();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Verify the message was sent successfully
        agent.read_with(cx, |agent, _cx| {
            assert_eq!(agent.thread_messages.len(), 2);
            assert_eq!(agent.thread_messages[0].role, Role::User);
            assert_eq!(agent.thread_messages[1].role, Role::Assistant);
        });
    }
6164
    #[gpui::test]
    async fn test_retry_rate_limit_with_custom_delay(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_workspace, _thread_store, agent, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Create a model that returns rate limit error with custom delay;
        // the retry should honor the server-provided `retry_after` rather
        // than the default backoff schedule.
        let custom_delay = Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS);
        let attempt_count = Arc::new(Mutex::new(0));
        let attempt_count_clone = attempt_count.clone();

        let retry_model = Arc::new(RateLimitTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            attempt_count: attempt_count_clone,
            fail_attempts: 1,
            retry_after: custom_delay,
        });

        agent.update(cx, |agent, cx| {
            agent.send_message("Hello", retry_model.clone(), None, cx);
        });

        // First attempt should fail with rate limit
        cx.run_until_parked();
        assert_eq!(*attempt_count.lock(), 1);

        // Advance clock by less than custom delay - should not retry yet
        cx.executor().advance_clock(custom_delay / 2);
        cx.run_until_parked();
        assert_eq!(*attempt_count.lock(), 1);

        // Advance clock to complete custom delay
        cx.executor().advance_clock(custom_delay / 2);
        cx.run_until_parked();

        // Second attempt should succeed
        assert_eq!(*attempt_count.lock(), 2);

        // Simulate successful response
        let fake_model = retry_model.as_fake();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Verify success
        agent.read_with(cx, |thread, _cx| {
            assert_eq!(thread.thread_messages.len(), 2);
            assert_eq!(thread.thread_messages[1].role, Role::Assistant);
        });
    }
6217
    // Test model that fails a specific number of times before delegating to
    // the wrapped fake model.
    struct RetryTestModel {
        // Fake model that handles requests once the failure budget is spent.
        inner: Arc<FakeLanguageModel>,
        // Shared counter of `stream_completion` calls; tests assert on it.
        attempt_count: Arc<Mutex<usize>>,
        // Number of initial calls that should fail.
        fail_attempts: usize,
        // Which error to return for the failing calls.
        error_type: TestError,
    }
6225
    // Delegates everything to `inner`, except that the first `fail_attempts`
    // calls to `stream_completion` return the configured error.
    impl LanguageModel for RetryTestModel {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        fn stream_completion(
            &self,
            request: LanguageModelRequest,
            cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Bump the shared attempt counter, releasing the lock before the
            // branch below so the returned future never holds it.
            let mut count = self.attempt_count.lock();
            *count += 1;
            let current_attempt = *count;
            drop(count);

            if current_attempt <= self.fail_attempts {
                // Still within the failure budget: return the injected error.
                let error = match self.error_type {
                    TestError::Overloaded => LanguageModelCompletionError::Overloaded,
                    TestError::InternalServerError => {
                        LanguageModelCompletionError::ApiInternalServerError
                    }
                };
                async move { Err(error) }.boxed()
            } else {
                self.inner.stream_completion(request, cx)
            }
        }

        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
6307
    // Test model for rate limit errors: fails the first `fail_attempts`
    // calls with `RateLimitExceeded` carrying a custom `retry_after`.
    struct RateLimitTestModel {
        // Fake model that handles requests once the failure budget is spent.
        inner: Arc<FakeLanguageModel>,
        // Shared counter of `stream_completion` calls; tests assert on it.
        attempt_count: Arc<Mutex<usize>>,
        // Number of initial calls that should fail.
        fail_attempts: usize,
        // Delay reported to the caller via the rate-limit error.
        retry_after: Duration,
    }
6315
    // Delegates everything to `inner`, except that the first `fail_attempts`
    // calls to `stream_completion` return a rate-limit error with the
    // configured `retry_after` delay.
    impl LanguageModel for RateLimitTestModel {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        fn stream_completion(
            &self,
            request: LanguageModelRequest,
            cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Bump the shared attempt counter, releasing the lock before the
            // branch below so the returned future never holds it.
            let mut count = self.attempt_count.lock();
            *count += 1;
            let current_attempt = *count;
            drop(count);

            if current_attempt <= self.fail_attempts {
                let error = LanguageModelCompletionError::RateLimitExceeded {
                    retry_after: self.retry_after,
                };
                async move { Err(error) }.boxed()
            } else {
                self.inner.stream_completion(request, cx)
            }
        }

        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
6394
    // Shared helper: sends a message, then ends the summary completion stream
    // without producing any text, driving the thread's summary into the
    // `Error` state. Callers continue from that state.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        agent: &Entity<ZedAgentThread>,
        cx: &mut TestAppContext,
    ) {
        agent.update(cx, |agent, cx| {
            agent.send_message("Hi", model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        agent.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending (with no streamed text)
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        agent.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
6423
6424 fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
6425 cx.run_until_parked();
6426 fake_model.stream_last_completion_response("Assist");
6427 fake_model.stream_last_completion_response("ant response");
6428 fake_model.end_last_completion_stream();
6429 cx.run_until_parked();
6430 }
6431
    // Installs the global settings store and registers every settings domain
    // the agent tests touch. Must run before any other setup helper.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The store must exist before the `init`/`register` calls below,
            // which read and write global settings through it.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
6447
6448 // Helper to create a test project with test files
6449 async fn create_test_project(
6450 cx: &mut TestAppContext,
6451 files: serde_json::Value,
6452 ) -> Entity<Project> {
6453 let fs = FakeFs::new(cx.executor());
6454 fs.insert_tree(path!("/test"), files).await;
6455 Project::test(fs, [path!("/test").as_ref()], cx).await
6456 }
6457
    // Stands up the full test fixture: a workspace window, a thread store, a
    // fresh thread, a context store, and a `FakeLanguageModel` registered as
    // both the default and the thread-summary model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<ZedAgentThread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let agent = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model for both regular completions and summary
        // generation so tests can drive all responses through it.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, agent, context_store, model)
    }
6512
6513 async fn add_file_to_context(
6514 project: &Entity<Project>,
6515 context_store: &Entity<ContextStore>,
6516 path: &str,
6517 cx: &mut TestAppContext,
6518 ) -> Result<Entity<language::Buffer>> {
6519 let buffer_path = project
6520 .read_with(cx, |project, cx| project.find_project_path(path, cx))
6521 .unwrap();
6522
6523 let buffer = project
6524 .update(cx, |project, cx| {
6525 project.open_buffer(buffer_path.clone(), cx)
6526 })
6527 .await
6528 .unwrap();
6529
6530 context_store.update(cx, |context_store, cx| {
6531 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
6532 });
6533
6534 Ok(buffer)
6535 }
6536
    #[gpui::test]
    async fn test_truncate(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_workspace, _thread_store, agent, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Send first message
        let message_id_1 = agent.update(cx, |agent, cx| {
            agent.send_message("First message", model.clone(), None, cx)
        });

        cx.run_until_parked();

        let fake_model = model.as_fake();
        fake_model.stream_last_completion_response("First response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Send second message
        let message_id_2 = agent.update(cx, |agent, cx| {
            agent.send_message("Second message", model.clone(), None, cx)
        });

        cx.run_until_parked();
        fake_model.stream_last_completion_response("Second response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Send third message
        agent.update(cx, |agent, cx| {
            agent.send_message("Third message", model.clone(), None, cx);
        });

        // Wait for completion to be registered
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Third response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Verify we have 6 messages (3 user + 3 assistant)
        agent.read_with(cx, |thread, _| {
            assert_eq!(thread.thread_messages.len(), 6);
        });

        // Truncate at the second user message; everything from that message
        // onward should be removed.
        agent.update(cx, |agent, cx| {
            agent.truncate(message_id_2, cx);
        });

        // Verify truncation: only the first user/assistant pair remains
        agent.read_with(cx, |thread, _| {
            assert_eq!(thread.thread_messages.len(), 2);
            assert_eq!(thread.thread_messages[0].id, message_id_1);
            assert_eq!(thread.thread_messages[0].role, Role::User);
            assert_eq!(thread.thread_messages[1].role, Role::Assistant);

            // Verify the truncated messages are gone
            assert!(thread.message(message_id_2).is_none());
        });

        // Verify internal state is consistent
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.messages.len(), 2); // Both user messages are in the messages vec
            assert!(!agent.thread_user_messages.contains_key(&message_id_2));
        });
    }
6605}