use crate::{
    agent_profile::AgentProfile,
    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
    thread_store::{
        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
        ThreadStore,
    },
    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
};
use action_log::ActionLog;
use agent_settings::{
    AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
    SUMMARIZE_THREAD_PROMPT,
};
use anyhow::{Result, anyhow};
use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, future::Shared};
use git::repository::DiffType;
use gpui::{
    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
    WeakEntity, Window,
};
use http_client::StatusCode;
use language_model::{
    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
    TokenUsage,
};
use postage::stream::Stream as _;
use project::{
    Project,
    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
};
use prompt_store::{ModelContext, PromptBuilder};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
    io::Write,
    ops::Range,
    sync::Arc,
    time::{Duration, Instant},
};
use thiserror::Error;
use util::{ResultExt as _, post_inc};
use uuid::Uuid;

const MAX_RETRY_ATTEMPTS: u8 = 4;
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);

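/// How a failed completion should be retried.
///
/// A minimal construction sketch; the retry loop that interprets these values
/// lives elsewhere in this file, and the doubling schedule shown in the
/// comment is an assumption for illustration, not a guarantee:
///
/// ```ignore
/// let strategy = RetryStrategy::ExponentialBackoff {
///     initial_delay: BASE_RETRY_DELAY,
///     max_attempts: MAX_RETRY_ATTEMPTS,
/// };
/// // E.g. attempt 1 waits 5s, attempt 2 waits 10s, attempt 3 waits 20s, ...
/// ```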
#[derive(Debug, Clone)]
enum RetryStrategy {
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}

#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
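    /// Creates a unique thread ID backed by a freshly generated UUID.
    ///
    /// IDs are opaque strings, so they survive a `Display`/`From<&str>`
    /// round-trip (illustrative sketch):
    ///
    /// ```ignore
    /// let id = ThreadId::new();
    /// assert_eq!(ThreadId::from(id.to_string().as_str()), id);
    /// ```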
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}

/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub usize);

impl MessageId {
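    /// Returns the current ID and advances `self` to the next one
    /// (post-increment semantics):
    ///
    /// ```ignore
    /// let mut next_id = MessageId(0);
    /// let first = next_id.post_inc();
    /// assert_eq!(first, MessageId(0));
    /// assert_eq!(next_id, MessageId(1));
    /// ```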
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    pub fn as_usize(&self) -> usize {
        self.0
    }
}

/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    pub segments: Vec<MessageSegment>,
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
    pub is_hidden: bool,
    pub ui_only: bool,
}

impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    /// The model sometimes runs tools without producing any text, or produces just a marker ([`USING_TOOL_MARKER`]).
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }

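    /// Appends a thinking chunk, coalescing it into the trailing `Thinking`
    /// segment when one exists (updating its signature); otherwise a new
    /// segment is started. Illustrative sketch:
    ///
    /// ```ignore
    /// message.push_thinking("part one, ", None);
    /// message.push_thinking("part two", Some("sig".into()));
    /// // Trailing segment: Thinking { text: "part one, part two", signature: Some("sig") }
    /// ```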
    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        if let Some(MessageSegment::Thinking {
            text: segment,
            signature: current_signature,
        }) = self.segments.last_mut()
        {
            if let Some(signature) = signature {
                *current_signature = Some(signature);
            }
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Thinking {
                text: text.to_string(),
                signature,
            });
        }
    }

    pub fn push_redacted_thinking(&mut self, data: String) {
        self.segments.push(MessageSegment::RedactedThinking(data));
    }

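    /// Appends a text chunk, extending the trailing `Text` segment when
    /// possible rather than creating a new segment per streamed chunk:
    ///
    /// ```ignore
    /// message.push_text("Hello, ");
    /// message.push_text("world");
    /// // Trailing segment: Text("Hello, world"), not two separate segments.
    /// ```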
    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

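    /// Flattens the message into plain text: loaded context first, then each
    /// segment, with thinking segments wrapped in `<think>` tags and redacted
    /// thinking omitted. Illustrative sketch:
    ///
    /// ```ignore
    /// // Segments [Thinking { text: "hmm", .. }, Text("Answer")] render as:
    /// // "<think>\nhmm\n</think>Answer"
    /// let content = message.to_message_content();
    /// ```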
    pub fn to_message_content(&self) -> String {
        let mut result = String::new();

        if !self.loaded_context.text.is_empty() {
            result.push_str(&self.loaded_context.text);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    result.push_str("<think>\n");
                    result.push_str(text);
                    result.push_str("\n</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        result
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(String),
}

impl MessageSegment {
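    /// Returns whether this segment carries user-visible content. Empty
    /// segments and redacted thinking should not be displayed:
    ///
    /// ```ignore
    /// assert!(MessageSegment::Text("hi".into()).should_display());
    /// assert!(!MessageSegment::Text(String::new()).should_display());
    /// assert!(!MessageSegment::RedactedThinking("...".into()).should_display());
    /// ```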
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }

    pub fn text(&self) -> Option<&str> {
        match self {
            MessageSegment::Text(text) => Some(text),
            _ => None,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}

impl LastRestoreCheckpoint {
    pub fn message_id(&self) -> MessageId {
        match self {
            LastRestoreCheckpoint::Pending { message_id } => *message_id,
            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
        }
    }
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}

impl DetailedSummaryState {
    fn text(&self) -> Option<SharedString> {
        if let Self::Generated { text, .. } = self {
            Some(text.clone())
        } else {
            None
        }
    }
}

#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: u64,
    pub max: u64,
}

impl TotalTokenUsage {
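    /// Classifies current usage against the context window. With the default
    /// 0.8 warning threshold:
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 85, max: 100 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Warning);
    ///
    /// let usage = TotalTokenUsage { total: 100, max: 100 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Exceeded);
    /// ```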
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or("0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

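    /// Returns a copy with `tokens` added to the running total, leaving the
    /// maximum unchanged; useful for projecting whether a pending request
    /// would cross a threshold:
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 90, max: 100 };
    /// assert_eq!(usage.add(15).ratio(), TokenUsageRatio::Exceeded);
    /// ```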
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}

#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}

/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
    last_error_context: Option<(Arc<dyn LanguageModel>, CompletionIntent)>,
}

#[derive(Clone, Debug)]
struct RetryState {
    attempt: u8,
    max_attempts: u8,
    intent: CompletionIntent,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    Pending,
    Generating,
    Ready(SharedString),
    Error,
}

impl ThreadSummary {
    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");

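    /// Returns the generated summary when one is ready, or the default
    /// placeholder otherwise:
    ///
    /// ```ignore
    /// assert_eq!(ThreadSummary::Pending.or_default(), "New Thread");
    /// assert_eq!(ThreadSummary::Ready("Fix panic".into()).or_default(), "Fix panic");
    /// ```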
    pub fn or_default(&self) -> SharedString {
        self.unwrap_or(Self::DEFAULT)
    }

    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
        self.ready().unwrap_or_else(|| message.into())
    }

    pub fn ready(&self) -> Option<SharedString> {
        match self {
            ThreadSummary::Ready(summary) => Some(summary.clone()),
            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}

impl Thread {
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_error_context: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }

    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            message_feedback: HashMap::default(),
            last_error_context: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }

    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events has gone stale,
    /// i.e. no chunk has been received for a while. Returns `None` when no
    /// completion is currently streaming.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250; // milliseconds

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }

    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        self.finalize_checkpoint(pending_checkpoint, cx);
    }

    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                this.update(cx, |this, cx| {
                    this.pending_checkpoint = if equal {
                        Some(pending_checkpoint)
                    } else {
                        this.insert_checkpoint(pending_checkpoint, cx);
                        Some(ThreadCheckpoint {
                            message_id: this.next_message_id,
                            git_checkpoint: final_checkpoint,
                        })
                    }
                })?;

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }

    pub fn is_turn_end(&self, ix: usize) -> bool {
        if self.messages.is_empty() {
            return false;
        }

        if !self.is_generating() && ix == self.messages.len() - 1 {
            return true;
        }

        let Some(message) = self.messages.get(ix) else {
            return false;
        };

        if message.role != Role::Assistant {
            return false;
        }

        self.messages
            .get(ix + 1)
            .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
            .unwrap_or(false)
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits.
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, &self.project, cx)
    }

    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display the image.
                None
            }
        }
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Returns the tools that are both enabled and supported by the model.
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            self.profile
                .enabled_tools(cx)
                .into_iter()
                .filter_map(|(name, tool)| {
                    // Skip tools whose input schema the model cannot support.
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name: name.into(),
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            Vec::default()
        }
    }

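    /// Inserts a user message, marking any referenced buffers as read in the
    /// action log and recording a pending checkpoint when one is provided.
    /// A minimal call sketch (assuming a `loaded: ContextLoadResult` and a
    /// `&mut Context<Thread>` are at hand):
    ///
    /// ```ignore
    /// let message_id =
    ///     thread.insert_user_message("Find the bug", loaded, None, Vec::new(), cx);
    /// ```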
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        message_id
    }

    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }

    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
            ui_only: false,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }

    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }

    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
    pub fn text(&self) -> String {
        let mut text = String::new();

        for message in &self.messages {
            text.push_str(match message.role {
                language_model::Role::User => "User:",
                language_model::Role::Assistant => "Agent:",
                language_model::Role::System => "System:",
            });
            text.push('\n');

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(content) => text.push_str(content),
                    MessageSegment::Thinking { text: content, .. } => {
                        text.push_str(&format!("<think>{}</think>", content))
                    }
                    MessageSegment::RedactedThinking(_) => {}
                }
            }
            text.push('\n');
        }

        text
    }

    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }

    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        self.flush_notifications(model.clone(), intent, cx);

        self.finalize_pending_checkpoint(cx);
        self.stream_completion(
            self.to_completion_request(model.clone(), intent, cx),
            model,
            intent,
            window,
            cx,
        );
    }

    pub fn retry_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        // Clear any existing error state.
        self.retry_state = None;

        // Use the last error context if available; otherwise fall back to the configured model.
        let (model, intent) = if let Some((model, intent)) = self.last_error_context.take() {
            (model, intent)
        } else if let Some(configured_model) = self.configured_model.as_ref() {
            let model = configured_model.model.clone();
            let intent = if self.has_pending_tool_uses() {
                CompletionIntent::ToolResults
            } else {
                CompletionIntent::UserPrompt
            };
            (model, intent)
        } else if let Some(configured_model) = self.get_or_init_configured_model(cx) {
            let model = configured_model.model.clone();
            let intent = if self.has_pending_tool_uses() {
                CompletionIntent::ToolResults
            } else {
                CompletionIntent::UserPrompt
            };
            (model, intent)
        } else {
            return;
        };

        self.send_to_model(model, intent, window, cx);
    }

    pub fn enable_burn_mode_and_retry(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.completion_mode = CompletionMode::Burn;
        cx.emit(ThreadEvent::ProfileChanged);
        self.retry_last_completion(window, cx);
    }

    pub fn used_tools_since_last_user_message(&self) -> bool {
        for message in self.messages.iter().rev() {
            if self.tool_use.message_has_tool_results(message.id) {
                return true;
            } else if message.role == Role::User {
                return false;
            }
        }

        false
    }

    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
            thinking_allowed: true,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model.
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // Mark the last eligible message as the prompt-cache anchor.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }

    fn to_summarize_request(
        &self,
        model: &Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        added_user_message: String,
        cx: &App,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(model, cx),
            thinking_allowed: false,
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => request_message
                        .content
                        .push(MessageContent::Text(text.clone())),
                    MessageSegment::Thinking { .. } => {}
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            if request_message.content.is_empty() {
                continue;
            }

            request.messages.push(request_message);
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(added_user_message)],
            cache: false,
        });

        request
    }

    /// Inserts auto-generated notifications (if any) into the thread.
    fn flush_notifications(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) {
        match intent {
            CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
                if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
                    cx.emit(ThreadEvent::ToolFinished {
                        tool_use_id: pending_tool_use.id.clone(),
                        pending_tool_use: Some(pending_tool_use),
                    });
                }
            }
            CompletionIntent::ThreadSummarization
            | CompletionIntent::ThreadContextSummarization
            | CompletionIntent::CreateFile
            | CompletionIntent::EditFile
            | CompletionIntent::InlineAssist
            | CompletionIntent::TerminalInlineAssist
            | CompletionIntent::GenerateGitCommitMessage => {}
        };
    }

    fn attach_tracked_files_state(
        &mut self,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Option<PendingToolUse> {
        // Represent the notification as a simulated `project_notifications` tool call.
        let tool_name = Arc::from("project_notifications");
        let tool = self.tools.read(cx).tool(&tool_name, cx)?;

        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return None;
        }

        if self
            .action_log
            .update(cx, |log, cx| log.unnotified_user_edits(cx).is_none())
        {
            return None;
        }

        let input = serde_json::json!({});
        let request = Arc::new(LanguageModelRequest::default()); // unused
        let window = None;
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model.clone(),
            window,
            cx,
        );

        let tool_use_id =
            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));

        let tool_use = LanguageModelToolUse {
            id: tool_use_id.clone(),
            name: tool_name.clone(),
            raw_input: "{}".to_string(),
            input: serde_json::json!({}),
            is_input_complete: true,
        };

        let tool_output = cx.background_executor().block(tool_result.output);

        // Attach a project_notification tool call to the latest existing
        // Assistant message. We cannot create a new Assistant message
        // because thinking models require a `thinking` block that we
        // cannot mock. We cannot send a notification as a normal
        // (non-tool-use) User message because this distracts the Agent
        // too much.
        let tool_message_id = self
            .messages
            .iter()
            .enumerate()
            .rfind(|(_, message)| message.role == Role::Assistant)
            .map(|(_, message)| message.id)?;

        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: self.last_prompt_id.clone(),
        };

        self.tool_use
            .request_tool_use(tool_message_id, tool_use, tool_use_metadata, cx);

        self.tool_use.insert_tool_output(
            tool_use_id,
            tool_name,
            tool_output,
            self.configured_model.as_ref(),
            self.completion_mode,
        )
    }

    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let completion_mode = request
            .mode
            .unwrap_or(cloud_llm_client::CompletionMode::Normal);

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        match event? {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as
                                        // it will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking { data } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some(tool_use.input.clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            } => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued { position } => {
                                            completion.queue_state =
                                                QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code,
                                            message,
                                            request_id: _,
                                            retry_after,
                                        } => {
                                            return Err(
                                                LanguageModelCompletionError::from_cloud_failure(
                                                    model.upstream_provider_name(),
                                                    code,
                                                    message,
                                                    retry_after.map(Duration::from_secs_f64),
                                                ),
                                            );
                                        }
                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
                                            thread.update_model_request_usage(
                                                amount as u32,
                                                limit,
                                                cx,
                                            );
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        Ok(())
                    })??;

                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;
            let mut retry_scheduled = false;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => {
                            match stop_reason {
                                StopReason::ToolUse => {
                                    let tool_uses =
                                        thread.use_pending_tools(window, model.clone(), cx);
                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                                }
                                StopReason::EndTurn | StopReason::MaxTokens => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });
                                }
                                StopReason::Refusal => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });

                                    // Remove the turn that was refused.
                                    //
                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                    {
                                        let mut messages_to_remove = Vec::new();

                                        for (ix, message) in
                                            thread.messages.iter().enumerate().rev()
                                        {
                                            messages_to_remove.push(message.id);

                                            if message.role == Role::User {
                                                if ix == 0 {
                                                    break;
                                                }

                                                if let Some(prev_message) =
                                                    thread.messages.get(ix - 1)
                                                    && prev_message.role == Role::Assistant
                                                {
                                                    break;
                                                }
                                            }
                                        }

                                        for message_id in messages_to_remove {
                                            thread.delete_message(message_id, cx);
                                        }
                                    }

                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                        header: "Language model refusal".into(),
                                        message:
                                            "Model refused to generate content for safety reasons."
                                                .into(),
                                    }));
                                }
                            }

                            // We successfully completed, so cancel any remaining retries.
                            thread.retry_state = None;
                        }
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(completion_error) =
                                error.downcast_ref::<LanguageModelCompletionError>()
                            {
                                match &completion_error {
                                    LanguageModelCompletionError::PromptTooLarge {
                                        tokens, ..
                                    } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread
                                                .total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(
                                                    model
                                                        .max_token_count_for_mode(completion_mode)
                                                        .saturating_add(1),
                                                )
                                        });
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: tokens,
                                        });
                                        cx.notify();
                                    }
                                    _ => {
                                        if let Some(retry_strategy) =
                                            Thread::get_retry_strategy(completion_error)
                                        {
                                            log::info!(
                                                "Retrying with {:?} for language model completion error {:?}",
                                                retry_strategy,
                                                completion_error
                                            );

                                            retry_scheduled = thread
                                                .handle_retryable_error_with_delay(
                                                    completion_error,
                                                    Some(retry_strategy),
                                                    model.clone(),
                                                    intent,
                                                    window,
                                                    cx,
                                                );
                                        }
                                    }
                                }
                            }

                            if !retry_scheduled {
                                thread.cancel_last_completion(window, cx);
                            }
                        }
                    }

                    if !retry_scheduled {
                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
                    }

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
2097
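    /// Generates a short, single-line summary of the thread using the
    /// configured thread-summary model, if one is available and authenticated.
    /// Emits [`ThreadEvent::SummaryGenerated`] when generation finishes.
    ///
    /// A minimal usage sketch (assuming an `Entity<Thread>` named `thread`):
    ///
    /// ```ignore
    /// thread.update(cx, |thread, cx| thread.summarize(cx));
    /// ```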
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            log::warn!("No thread summary model configured");
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let request = self.to_summarize_request(
            &model.model,
            CompletionIntent::ThreadSummarization,
            SUMMARIZE_THREAD_PROMPT.into(),
            cx,
        );

        self.summary = ThreadSummary::Generating;

        self.pending_summary = cx.spawn(async move |this, cx| {
            let result = async {
                let mut messages = model.model.stream_completion(request, cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let Ok(event) = event else {
                        continue;
                    };
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            this.update(cx, |thread, cx| {
                                thread.update_model_request_usage(amount as u32, limit, cx);
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                anyhow::Ok(new_summary)
            }
            .await;

            this.update(cx, |this, cx| {
                match result {
                    Ok(new_summary) => {
                        if new_summary.is_empty() {
                            this.summary = ThreadSummary::Error;
                        } else {
                            this.summary = ThreadSummary::Ready(new_summary.into());
                        }
                    }
                    Err(err) => {
                        this.summary = ThreadSummary::Error;
                        log::error!("Failed to generate thread summary: {}", err);
                    }
                }
                cx.emit(ThreadEvent::SummaryGenerated);
            })
            .log_err()?;

            Some(())
        });
    }

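    /// Maps a completion error to the retry strategy we should use for it, or
    /// `None` when retrying cannot help (e.g. authentication failures or an
    /// oversized prompt).
    ///
    /// An illustrative (non-exhaustive) sketch of the mapping:
    ///
    /// ```ignore
    /// // HTTP 429 from the API        -> exponential backoff, up to 4 attempts
    /// // Overloaded / rate limited    -> fixed delay, honoring any retry-after hint
    /// // Bad API key / prompt too big -> None (retrying won't help)
    /// ```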
    fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
        use LanguageModelCompletionError::*;

        // General strategy here:
        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times,
        //   backing off exponentially or honoring the server's retry-after hint.
        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
        match error {
            HttpResponseError {
                status_code: StatusCode::TOO_MANY_REQUESTS,
                ..
            } => Some(RetryStrategy::ExponentialBackoff {
                initial_delay: BASE_RETRY_DELAY,
                max_attempts: MAX_RETRY_ATTEMPTS,
            }),
            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
                Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    max_attempts: MAX_RETRY_ATTEMPTS,
                })
            }
            UpstreamProviderError {
                status,
                retry_after,
                ..
            } => match *status {
                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
                    Some(RetryStrategy::Fixed {
                        delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                        max_attempts: MAX_RETRY_ATTEMPTS,
                    })
                }
                StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    // Internal Server Error could be anything, retry up to 3 times.
                    max_attempts: 3,
                }),
                status => {
                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
                    // but we see it frequently in practice. See https://http.dev/529
                    if status.as_u16() == 529 {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: MAX_RETRY_ATTEMPTS,
                        })
                    } else {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: 2,
                        })
                    }
                }
            },
            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            ApiReadResponseError { .. }
            | HttpSend { .. }
            | DeserializeResponse { .. }
            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Retrying these errors definitely shouldn't help.
            HttpResponseError {
                status_code:
                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
                ..
            }
            | AuthenticationError { .. }
            | PermissionError { .. }
            | NoApiKey { .. }
            | ApiEndpointNotFound { .. }
            | PromptTooLarge { .. } => None,
            // These errors might be transient, so retry them once.
            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 1,
            }),
            // Retry all other 4xx and 5xx errors up to 3 times.
            HttpResponseError { status_code, .. }
                if status_code.is_client_error() || status_code.is_server_error() =>
            {
                Some(RetryStrategy::Fixed {
                    delay: BASE_RETRY_DELAY,
                    max_attempts: 3,
                })
            }
            Other(err)
                if err.is::<PaymentRequiredError>()
                    || err.is::<ModelRequestLimitReachedError>() =>
            {
                // Retrying won't help for Payment Required or Model Request Limit errors (where
                // the user must upgrade to usage-based billing to get more requests, or else wait
                // for a significant amount of time for the request limit to reset).
                None
            }
            // Conservatively assume that any other errors are non-retryable
            HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 2,
            }),
        }
    }

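    /// Handles an error that may warrant a retry, scheduling one according to
    /// `strategy` (or, if `None`, the strategy derived from the error).
    ///
    /// Returns `true` if a retry was scheduled. Automatic retries only happen
    /// in Burn Mode; otherwise the error is surfaced with a retry option.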
    fn handle_retryable_error_with_delay(
        &mut self,
        error: &LanguageModelCompletionError,
        strategy: Option<RetryStrategy>,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Store context for the Retry button
        self.last_error_context = Some((model.clone(), intent));

        // Only auto-retry if Burn Mode is enabled
        if self.completion_mode != CompletionMode::Burn {
            // Show error with retry options
            cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
                message: format!(
                    "{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.",
                    error
                )
                .into(),
                can_enable_burn_mode: true,
            }));
            return false;
        }

        let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
            return false;
        };

        let max_attempts = match &strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
        };

        let retry_state = self.retry_state.get_or_insert(RetryState {
            attempt: 0,
            max_attempts,
            intent,
        });

        retry_state.attempt += 1;
        let attempt = retry_state.attempt;
        let max_attempts = retry_state.max_attempts;
        let intent = retry_state.intent;

        if attempt <= max_attempts {
            let delay = match &strategy {
                RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
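                    // The delay doubles with each attempt: with the 5-second
                    // base delay this yields 5s, 10s, 20s, and 40s for
                    // attempts 1 through 4.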
                    let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
                    Duration::from_secs(delay_secs)
                }
                RetryStrategy::Fixed { delay, .. } => *delay,
            };

            // Add a transient message to inform the user
            let delay_secs = delay.as_secs();
            let retry_message = if max_attempts == 1 {
                format!("{error}. Retrying in {delay_secs} seconds...")
            } else {
                format!(
                    "{error}. Retrying (attempt {attempt} of {max_attempts}) \
                     in {delay_secs} seconds..."
                )
            };
            log::warn!(
                "Retrying completion request (attempt {attempt} of {max_attempts}) \
                 in {delay_secs} seconds: {error:?}",
            );

            // Add a UI-only message instead of a regular message
            let id = self.next_message_id.post_inc();
            self.messages.push(Message {
                id,
                role: Role::System,
                segments: vec![MessageSegment::Text(retry_message)],
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
                is_hidden: false,
                ui_only: true,
            });
            cx.emit(ThreadEvent::MessageAdded(id));

            // Schedule the retry
            let thread_handle = cx.entity().downgrade();

            cx.spawn(async move |_thread, cx| {
                cx.background_executor().timer(delay).await;

                thread_handle
                    .update(cx, |thread, cx| {
                        // Retry the completion
                        thread.send_to_model(model, intent, window, cx);
                    })
                    .log_err();
            })
            .detach();

            true
        } else {
            // Max retries exceeded
            self.retry_state = None;

            // Stop generating since we're giving up on retrying.
            self.pending_completions.clear();

            // Show error alongside a Retry button, but no
            // Enable Burn Mode button (since it's already enabled)
            cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
                message: format!("Failed after retrying: {}", error).into(),
                can_enable_burn_mode: false,
            }));

            false
        }
    }

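    /// Starts generating a detailed summary in the background, unless one has
    /// already been generated (or is being generated) for the current last
    /// message.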
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            SUMMARIZE_THREAD_DETAILED_PROMPT.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, cx);
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade()
                && let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                })
            {
                save_task.await.log_err();
            }

            Some(())
        });
    }

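    /// Waits for detailed-summary generation to settle, returning the detailed
    /// summary if one was generated, or the thread's full text otherwise.
    ///
    /// A minimal usage sketch (assuming an `Entity<Thread>` named `thread` and
    /// an `AsyncApp` context):
    ///
    /// ```ignore
    /// if let Some(text) = Thread::wait_for_detailed_summary_or_text(&thread, cx).await {
    ///     // `text` is either the detailed summary or the full thread text.
    /// }
    /// ```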
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }

    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.detailed_summary_rx
            .borrow()
            .text()
            .unwrap_or_else(|| self.text().into())
    }

    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }

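    /// Runs every pending tool use that is still idle, returning the pending
    /// tool uses that were started.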
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> Vec<PendingToolUse> {
        let request =
            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
        }

        pending_tool_uses
    }

    fn use_pending_tool(
        &mut self,
        tool_use: PendingToolUse,
        request: Arc<LanguageModelRequest>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
        };

        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
        }

        if tool.needs_confirmation(&tool_use.input, &self.project, cx)
            && !AgentSettings::get_global(cx).always_allow_tool_actions
        {
            self.tool_use.confirm_tool_use(
                tool_use.id,
                tool_use.ui_text,
                tool_use.input,
                request,
                tool,
            );
            cx.emit(ThreadEvent::ToolConfirmationNeeded);
        } else {
            self.run_tool(
                tool_use.id,
                tool_use.ui_text,
                tool_use.input,
                request,
                tool,
                model,
                window,
                cx,
            );
        }
    }

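    /// Responds to a tool call for a tool that doesn't exist or isn't enabled
    /// in the current profile, reporting an error result that lists the
    /// available tools so the model can correct itself on its next turn.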
    pub fn handle_hallucinated_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        hallucinated_tool_name: Arc<str>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let available_tools = self.profile.enabled_tools(cx);

        let tool_list = available_tools
            .iter()
            .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
            .collect::<Vec<_>>()
            .join("\n");

        let error_message = format!(
            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
            hallucinated_tool_name, tool_list
        );

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            hallucinated_tool_name,
            Err(anyhow!("Missing tool call: {error_message}")),
            self.configured_model.as_ref(),
            self.completion_mode,
        );

        cx.emit(ThreadEvent::MissingToolUse {
            tool_use_id: tool_use_id.clone(),
            ui_text: error_message.into(),
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }

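    /// Reports a tool call whose input JSON failed to parse back to the model
    /// as an error tool result.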
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
            self.completion_mode,
        );
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }

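    /// Runs the given tool with the provided input, recording its output as a
    /// tool result once it completes.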
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }

    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                            thread.completion_mode,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }

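    /// Records that a tool finished running. Once all pending tools have
    /// finished (and none were canceled), the tool results are automatically
    /// sent back to the model.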
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished()
            && let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref()
            && !canceled
        {
            self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }

    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();

        self.retry_state = None;

        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }

    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }

    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }

    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|(name, _)| name.clone().into())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_message_content())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }

    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, _| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                timestamp: Utc::now(),
            })
        })
    }

    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().into_owned();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            let git_state = match git_state.and_then(|state| state.ok()) {
                Some(job) => job.await.ok(),
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }

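    /// Renders the entire thread as Markdown, including attached context,
    /// message segments, tool uses, and tool results.
    ///
    /// A minimal usage sketch (assuming an `Entity<Thread>` named `thread` and
    /// an `&App` context):
    ///
    /// ```ignore
    /// let markdown = thread.read(cx).to_markdown(cx)?;
    /// println!("{markdown}");
    /// ```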
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        let summary = self.summary().or_default();
        writeln!(markdown, "# {summary}\n")?;

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Agent",
                    Role::System => "System",
                }
            )?;

            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                match &tool_result.content {
                    LanguageModelToolResultContent::Text(text) => {
                        writeln!(markdown, "{text}")?;
                    }
                    LanguageModelToolResultContent::Image(image) => {
                        // Embed the image as a Markdown image with a data URI.
                        writeln!(markdown, "![Image](data:image/png;base64,{})", image.source)?;
                    }
                }

                if let Some(output) = tool_result.output.as_ref() {
                    writeln!(
                        markdown,
                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
                        serde_json::to_string_pretty(output)?
                    )?;
                }
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }

    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }

    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }

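    /// Returns the token usage as of the request that preceded `message_id`,
    /// so a message's reported usage reflects the context sent up to (but not
    /// including) that message.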
    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
        let Some(model) = self.configured_model.as_ref() else {
            return TotalTokenUsage::default();
        };

        let max = model
            .model
            .max_token_count_for_mode(self.completion_mode().into());

        let index = self
            .messages
            .iter()
            .position(|msg| msg.id == message_id)
            .unwrap_or(0);

        if index == 0 {
            return TotalTokenUsage { total: 0, max };
        }

        let token_usage = &self
            .request_token_usage
            .get(index - 1)
            .cloned()
            .unwrap_or_default();

        TotalTokenUsage {
            total: token_usage.total_tokens(),
            max,
        }
    }

    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
        let model = self.configured_model.as_ref()?;

        let max = model
            .model
            .max_token_count_for_mode(self.completion_mode().into());

        if let Some(exceeded_error) = &self.exceeded_window_error
            && model.model.id() == exceeded_error.model_id
        {
            return Some(TotalTokenUsage {
                total: exceeded_error.token_count,
                max,
            });
        }

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens();

        Some(TotalTokenUsage { total, max })
    }

    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }

    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }

    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project
            .read(cx)
            .user_store()
            .update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            });
    }

    pub fn deny_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let err = Err(anyhow::anyhow!(
            "Permission to run tool action denied by user"
        ));

        self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            err,
            self.configured_model.as_ref(),
            self.completion_mode,
        );
        self.tool_finished(tool_use_id, None, true, window, cx);
    }
}

#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
    #[error("Retryable error: {message}")]
    RetryableError {
        message: SharedString,
        can_enable_burn_mode: bool,
    },
}

#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
}

impl EventEmitter<ThreadEvent> for Thread {}

struct PendingCompletion {
    id: usize,
    queue_state: QueueState,
    _task: Task<()>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
    };
    use agent_settings::{AgentProfileId, AgentSettings};
    use assistant_tool::ToolRegistry;
    use assistant_tools;
    use fs::Fs;
    use futures::StreamExt;
    use futures::future::BoxFuture;
    use futures::stream::BoxStream;
    use gpui::TestAppContext;
    use http_client;
    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
    use language_model::{
        LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
        LanguageModelProviderName, LanguageModelToolChoice,
    };
    use parking_lot::Mutex;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{LanguageModelParameters, Settings, SettingsStore};
    use std::sync::Arc;
    use std::time::Duration;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // Test-specific constants
    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;

    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
3412 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
3413 .await
3414 .unwrap();
3415 let new_contexts = context_store.update(cx, |store, cx| {
3416 store.new_context_for_thread(thread.read(cx), None)
3417 });
3418 assert_eq!(new_contexts.len(), 1);
3419 let loaded_context = cx
3420 .update(|cx| load_context(new_contexts, &project, &None, cx))
3421 .await;
3422 let message3_id = thread.update(cx, |thread, cx| {
3423 thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
3424 });
3425
3426 // Check what contexts are included in each message
3427 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
3428 (
3429 thread.message(message1_id).unwrap().clone(),
3430 thread.message(message2_id).unwrap().clone(),
3431 thread.message(message3_id).unwrap().clone(),
3432 )
3433 });
3434
3435 // First message should include context 1
3436 assert!(message1.loaded_context.text.contains("file1.rs"));
3437
3438 // Second message should include only context 2 (not 1)
3439 assert!(!message2.loaded_context.text.contains("file1.rs"));
3440 assert!(message2.loaded_context.text.contains("file2.rs"));
3441
3442 // Third message should include only context 3 (not 1 or 2)
3443 assert!(!message3.loaded_context.text.contains("file1.rs"));
3444 assert!(!message3.loaded_context.text.contains("file2.rs"));
3445 assert!(message3.loaded_context.text.contains("file3.rs"));
3446
3447 // Check entire request to make sure all contexts are properly included
3448 let request = thread.update(cx, |thread, cx| {
3449 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3450 });
3451
3452 // The request should contain all 3 messages
3453 assert_eq!(request.messages.len(), 4);
3454
3455 // Check that the contexts are properly formatted in each message
3456 assert!(request.messages[1].string_contents().contains("file1.rs"));
3457 assert!(!request.messages[1].string_contents().contains("file2.rs"));
3458 assert!(!request.messages[1].string_contents().contains("file3.rs"));
3459
3460 assert!(!request.messages[2].string_contents().contains("file1.rs"));
3461 assert!(request.messages[2].string_contents().contains("file2.rs"));
3462 assert!(!request.messages[2].string_contents().contains("file3.rs"));
3463
3464 assert!(!request.messages[3].string_contents().contains("file1.rs"));
3465 assert!(!request.messages[3].string_contents().contains("file2.rs"));
3466 assert!(request.messages[3].string_contents().contains("file3.rs"));
3467
3468 add_file_to_context(&project, &context_store, "test/file4.rs", cx)
3469 .await
3470 .unwrap();
3471 let new_contexts = context_store.update(cx, |store, cx| {
3472 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3473 });
3474 assert_eq!(new_contexts.len(), 3);
3475 let loaded_context = cx
3476 .update(|cx| load_context(new_contexts, &project, &None, cx))
3477 .await
3478 .loaded_context;
3479
3480 assert!(!loaded_context.text.contains("file1.rs"));
3481 assert!(loaded_context.text.contains("file2.rs"));
3482 assert!(loaded_context.text.contains("file3.rs"));
3483 assert!(loaded_context.text.contains("file4.rs"));
3484
3485 let new_contexts = context_store.update(cx, |store, cx| {
3486 // Remove file4.rs
3487 store.remove_context(&loaded_context.contexts[2].handle(), cx);
3488 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3489 });
3490 assert_eq!(new_contexts.len(), 2);
3491 let loaded_context = cx
3492 .update(|cx| load_context(new_contexts, &project, &None, cx))
3493 .await
3494 .loaded_context;
3495
3496 assert!(!loaded_context.text.contains("file1.rs"));
3497 assert!(loaded_context.text.contains("file2.rs"));
3498 assert!(loaded_context.text.contains("file3.rs"));
3499 assert!(!loaded_context.text.contains("file4.rs"));
3500
3501 let new_contexts = context_store.update(cx, |store, cx| {
3502 // Remove file3.rs
3503 store.remove_context(&loaded_context.contexts[1].handle(), cx);
3504 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3505 });
3506 assert_eq!(new_contexts.len(), 1);
3507 let loaded_context = cx
3508 .update(|cx| load_context(new_contexts, &project, &None, cx))
3509 .await
3510 .loaded_context;
3511
3512 assert!(!loaded_context.text.contains("file1.rs"));
3513 assert!(loaded_context.text.contains("file2.rs"));
3514 assert!(!loaded_context.text.contains("file3.rs"));
3515 assert!(!loaded_context.text.contains("file4.rs"));
3516 }
3517
3518 #[gpui::test]
3519 async fn test_message_without_files(cx: &mut TestAppContext) {
3520 let fs = init_test_settings(cx);
3521
3522 let project = create_test_project(
3523 &fs,
3524 cx,
3525 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3526 )
3527 .await;
3528
3529 let (_, _thread_store, thread, _context_store, model) =
3530 setup_test_environment(cx, project.clone()).await;
3531
3532 // Insert user message without any context (empty context vector)
3533 let message_id = thread.update(cx, |thread, cx| {
3534 thread.insert_user_message(
3535 "What is the best way to learn Rust?",
3536 ContextLoadResult::default(),
3537 None,
3538 Vec::new(),
3539 cx,
3540 )
3541 });
3542
3543 // Check content and context in message object
3544 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3545
3546 // Context should be empty when no files are included
3547 assert_eq!(message.role, Role::User);
3548 assert_eq!(message.segments.len(), 1);
3549 assert_eq!(
3550 message.segments[0],
3551 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3552 );
3553 assert_eq!(message.loaded_context.text, "");
3554
3555 // Check message in request
3556 let request = thread.update(cx, |thread, cx| {
3557 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3558 });
3559
3560 assert_eq!(request.messages.len(), 2);
3561 assert_eq!(
3562 request.messages[1].string_contents(),
3563 "What is the best way to learn Rust?"
3564 );
3565
3566 // Add second message, also without context
3567 let message2_id = thread.update(cx, |thread, cx| {
3568 thread.insert_user_message(
3569 "Are there any good books?",
3570 ContextLoadResult::default(),
3571 None,
3572 Vec::new(),
3573 cx,
3574 )
3575 });
3576
3577 let message2 =
3578 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3579 assert_eq!(message2.loaded_context.text, "");
3580
3581 // Check that both messages appear in the request
3582 let request = thread.update(cx, |thread, cx| {
3583 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3584 });
3585
3586 assert_eq!(request.messages.len(), 3);
3587 assert_eq!(
3588 request.messages[1].string_contents(),
3589 "What is the best way to learn Rust?"
3590 );
3591 assert_eq!(
3592 request.messages[2].string_contents(),
3593 "Are there any good books?"
3594 );
3595 }
3596
3597 #[gpui::test]
3598 #[ignore] // turn this test on when project_notifications tool is re-enabled
3599 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
3600 let fs = init_test_settings(cx);
3601
3602 let project = create_test_project(
3603 &fs,
3604 cx,
3605 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3606 )
3607 .await;
3608
3609 let (_workspace, _thread_store, thread, context_store, model) =
3610 setup_test_environment(cx, project.clone()).await;
3611
3612 // Add a buffer to the context. This will be a tracked buffer
3613 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
3614 .await
3615 .unwrap();
3616
3617 let context = context_store
3618 .read_with(cx, |store, _| store.context().next().cloned())
3619 .unwrap();
3620 let loaded_context = cx
3621 .update(|cx| load_context(vec![context], &project, &None, cx))
3622 .await;
3623
3624 // Insert user message and assistant response
3625 thread.update(cx, |thread, cx| {
3626 thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
3627 thread.insert_assistant_message(
3628 vec![MessageSegment::Text("This code prints 42.".into())],
3629 cx,
3630 );
3631 });
3632 cx.run_until_parked();
3633
3634 // We shouldn't have a stale buffer notification yet
3635 let notifications = thread.read_with(cx, |thread, _| {
3636 find_tool_uses(thread, "project_notifications")
3637 });
3638 assert!(
3639 notifications.is_empty(),
3640 "Should not have stale buffer notification before buffer is modified"
3641 );
3642
3643 // Modify the buffer
3644 buffer.update(cx, |buffer, cx| {
3645 buffer.edit(
3646 [(1..1, "\n println!(\"Added a new line\");\n")],
3647 None,
3648 cx,
3649 );
3650 });
3651
3652 // Insert another user message
3653 thread.update(cx, |thread, cx| {
3654 thread.insert_user_message(
3655 "What does the code do now?",
3656 ContextLoadResult::default(),
3657 None,
3658 Vec::new(),
3659 cx,
3660 )
3661 });
3662 cx.run_until_parked();
3663
3664 // Check for the stale buffer warning
3665 thread.update(cx, |thread, cx| {
3666 thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
3667 });
3668 cx.run_until_parked();
3669
3670 let notifications = thread.read_with(cx, |thread, _cx| {
3671 find_tool_uses(thread, "project_notifications")
3672 });
3673
3674 let [notification] = notifications.as_slice() else {
3675 panic!("Should have a `project_notifications` tool use");
3676 };
3677
3678 let Some(notification_content) = notification.content.to_str() else {
3679 panic!("`project_notifications` should return text");
3680 };
3681
3682 assert!(notification_content.contains("These files have changed since the last read:"));
3683 assert!(notification_content.contains("code.rs"));
3684
3685 // Insert another user message and flush notifications again
3686 thread.update(cx, |thread, cx| {
3687 thread.insert_user_message(
3688 "Can you tell me more?",
3689 ContextLoadResult::default(),
3690 None,
3691 Vec::new(),
3692 cx,
3693 )
3694 });
3695
3696 thread.update(cx, |thread, cx| {
3697 thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
3698 });
3699 cx.run_until_parked();
3700
3701 // There should be no new notifications (we already flushed one)
3702 let notifications = thread.read_with(cx, |thread, _cx| {
3703 find_tool_uses(thread, "project_notifications")
3704 });
3705
3706 assert_eq!(
3707 notifications.len(),
3708 1,
3709 "Should still have only one notification after second flush - no duplicates"
3710 );
3711 }
3712
3713 fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3714 thread
3715 .messages()
3716 .flat_map(|message| {
3717 thread
3718 .tool_results_for_message(message.id)
3719 .into_iter()
3720 .filter(|result| result.tool_name == tool_name.into())
3721 .cloned()
3722 .collect::<Vec<_>>()
3723 })
3724 .collect()
3725 }
3726
3727 #[gpui::test]
3728 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3729 let fs = init_test_settings(cx);
3730
3731 let project = create_test_project(
3732 &fs,
3733 cx,
3734 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3735 )
3736 .await;
3737
3738 let (_workspace, thread_store, thread, _context_store, _model) =
3739 setup_test_environment(cx, project.clone()).await;
3740
3741 // Check that we are starting with the default profile
3742 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3743 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3744 assert_eq!(
3745 profile,
3746 AgentProfile::new(AgentProfileId::default(), tool_set)
3747 );
3748 }
3749
3750 #[gpui::test]
3751 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3752 let fs = init_test_settings(cx);
3753
3754 let project = create_test_project(
3755 &fs,
3756 cx,
3757 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3758 )
3759 .await;
3760
3761 let (_workspace, thread_store, thread, _context_store, _model) =
3762 setup_test_environment(cx, project.clone()).await;
3763
3764 // Profile gets serialized with default values
3765 let serialized = thread
3766 .update(cx, |thread, cx| thread.serialize(cx))
3767 .await
3768 .unwrap();
3769
3770 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3771
3772 let deserialized = cx.update(|cx| {
3773 thread.update(cx, |thread, cx| {
3774 Thread::deserialize(
3775 thread.id.clone(),
3776 serialized,
3777 thread.project.clone(),
3778 thread.tools.clone(),
3779 thread.prompt_builder.clone(),
3780 thread.project_context.clone(),
3781 None,
3782 cx,
3783 )
3784 })
3785 });
3786 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3787
3788 assert_eq!(
3789 deserialized.profile,
3790 AgentProfile::new(AgentProfileId::default(), tool_set)
3791 );
3792 }
3793
3794 #[gpui::test]
3795 async fn test_temperature_setting(cx: &mut TestAppContext) {
3796 let fs = init_test_settings(cx);
3797
3798 let project = create_test_project(
3799 &fs,
3800 cx,
3801 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3802 )
3803 .await;
3804
3805 let (_workspace, _thread_store, thread, _context_store, model) =
3806 setup_test_environment(cx, project.clone()).await;
3807
3808 // Both model and provider
3809 cx.update(|cx| {
3810 AgentSettings::override_global(
3811 AgentSettings {
3812 model_parameters: vec![LanguageModelParameters {
3813 provider: Some(model.provider_id().0.to_string().into()),
3814 model: Some(model.id().0),
3815 temperature: Some(0.66),
3816 }],
3817 ..AgentSettings::get_global(cx).clone()
3818 },
3819 cx,
3820 );
3821 });
3822
3823 let request = thread.update(cx, |thread, cx| {
3824 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3825 });
3826 assert_eq!(request.temperature, Some(0.66));
3827
3828 // Only model
3829 cx.update(|cx| {
3830 AgentSettings::override_global(
3831 AgentSettings {
3832 model_parameters: vec![LanguageModelParameters {
3833 provider: None,
3834 model: Some(model.id().0),
3835 temperature: Some(0.66),
3836 }],
3837 ..AgentSettings::get_global(cx).clone()
3838 },
3839 cx,
3840 );
3841 });
3842
3843 let request = thread.update(cx, |thread, cx| {
3844 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3845 });
3846 assert_eq!(request.temperature, Some(0.66));
3847
3848 // Only provider
3849 cx.update(|cx| {
3850 AgentSettings::override_global(
3851 AgentSettings {
3852 model_parameters: vec![LanguageModelParameters {
3853 provider: Some(model.provider_id().0.to_string().into()),
3854 model: None,
3855 temperature: Some(0.66),
3856 }],
3857 ..AgentSettings::get_global(cx).clone()
3858 },
3859 cx,
3860 );
3861 });
3862
3863 let request = thread.update(cx, |thread, cx| {
3864 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3865 });
3866 assert_eq!(request.temperature, Some(0.66));
3867
3868 // Same model name, different provider
3869 cx.update(|cx| {
3870 AgentSettings::override_global(
3871 AgentSettings {
3872 model_parameters: vec![LanguageModelParameters {
3873 provider: Some("anthropic".into()),
3874 model: Some(model.id().0),
3875 temperature: Some(0.66),
3876 }],
3877 ..AgentSettings::get_global(cx).clone()
3878 },
3879 cx,
3880 );
3881 });
3882
3883 let request = thread.update(cx, |thread, cx| {
3884 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3885 });
3886 assert_eq!(request.temperature, None);
3887 }
3888
3889 #[gpui::test]
3890 async fn test_thread_summary(cx: &mut TestAppContext) {
3891 let fs = init_test_settings(cx);
3892
3893 let project = create_test_project(&fs, cx, json!({})).await;
3894
3895 let (_, _thread_store, thread, _context_store, model) =
3896 setup_test_environment(cx, project.clone()).await;
3897
3898 // Initial state should be pending
3899 thread.read_with(cx, |thread, _| {
3900 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3901 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3902 });
3903
3904 // Manually setting the summary should not be allowed in this state
3905 thread.update(cx, |thread, cx| {
3906 thread.set_summary("This should not work", cx);
3907 });
3908
3909 thread.read_with(cx, |thread, _| {
3910 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3911 });
3912
3913 // Send a message
3914 thread.update(cx, |thread, cx| {
3915 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3916 thread.send_to_model(
3917 model.clone(),
                CompletionIntent::UserPrompt,
3919 None,
3920 cx,
3921 );
3922 });
3923
3924 let fake_model = model.as_fake();
3925 simulate_successful_response(fake_model, cx);
3926
3927 // Should start generating summary when there are >= 2 messages
3928 thread.read_with(cx, |thread, _| {
3929 assert_eq!(*thread.summary(), ThreadSummary::Generating);
3930 });
3931
3932 // Should not be able to set the summary while generating
3933 thread.update(cx, |thread, cx| {
3934 thread.set_summary("This should not work either", cx);
3935 });
3936
3937 thread.read_with(cx, |thread, _| {
3938 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3939 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3940 });
3941
3942 cx.run_until_parked();
3943 fake_model.send_last_completion_stream_text_chunk("Brief");
3944 fake_model.send_last_completion_stream_text_chunk(" Introduction");
3945 fake_model.end_last_completion_stream();
3946 cx.run_until_parked();
3947
3948 // Summary should be set
3949 thread.read_with(cx, |thread, _| {
3950 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3951 assert_eq!(thread.summary().or_default(), "Brief Introduction");
3952 });
3953
3954 // Now we should be able to set a summary
3955 thread.update(cx, |thread, cx| {
3956 thread.set_summary("Brief Intro", cx);
3957 });
3958
3959 thread.read_with(cx, |thread, _| {
3960 assert_eq!(thread.summary().or_default(), "Brief Intro");
3961 });
3962
3963 // Test setting an empty summary (should default to DEFAULT)
3964 thread.update(cx, |thread, cx| {
3965 thread.set_summary("", cx);
3966 });
3967
3968 thread.read_with(cx, |thread, _| {
3969 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3970 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3971 });
3972 }
3973
3974 #[gpui::test]
3975 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3976 let fs = init_test_settings(cx);
3977
3978 let project = create_test_project(&fs, cx, json!({})).await;
3979
3980 let (_, _thread_store, thread, _context_store, model) =
3981 setup_test_environment(cx, project.clone()).await;
3982
3983 test_summarize_error(&model, &thread, cx);
3984
3985 // Now we should be able to set a summary
3986 thread.update(cx, |thread, cx| {
3987 thread.set_summary("Brief Intro", cx);
3988 });
3989
3990 thread.read_with(cx, |thread, _| {
3991 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3992 assert_eq!(thread.summary().or_default(), "Brief Intro");
3993 });
3994 }
3995
3996 #[gpui::test]
3997 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3998 let fs = init_test_settings(cx);
3999
4000 let project = create_test_project(&fs, cx, json!({})).await;
4001
4002 let (_, _thread_store, thread, _context_store, model) =
4003 setup_test_environment(cx, project.clone()).await;
4004
4005 test_summarize_error(&model, &thread, cx);
4006
4007 // Sending another message should not trigger another summarize request
4008 thread.update(cx, |thread, cx| {
4009 thread.insert_user_message(
4010 "How are you?",
4011 ContextLoadResult::default(),
4012 None,
4013 vec![],
4014 cx,
4015 );
4016 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4017 });
4018
4019 let fake_model = model.as_fake();
4020 simulate_successful_response(fake_model, cx);
4021
4022 thread.read_with(cx, |thread, _| {
4023 // State is still Error, not Generating
4024 assert!(matches!(thread.summary(), ThreadSummary::Error));
4025 });
4026
4027 // But the summarize request can be invoked manually
4028 thread.update(cx, |thread, cx| {
4029 thread.summarize(cx);
4030 });
4031
4032 thread.read_with(cx, |thread, _| {
4033 assert!(matches!(thread.summary(), ThreadSummary::Generating));
4034 });
4035
4036 cx.run_until_parked();
4037 fake_model.send_last_completion_stream_text_chunk("A successful summary");
4038 fake_model.end_last_completion_stream();
4039 cx.run_until_parked();
4040
4041 thread.read_with(cx, |thread, _| {
4042 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4043 assert_eq!(thread.summary().or_default(), "A successful summary");
4044 });
4045 }
4046
    // Helpers for simulating provider errors in tests
4048 enum TestError {
4049 Overloaded,
4050 InternalServerError,
4051 }
4052
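    /// Wraps a `FakeLanguageModel` and injects the configured error into
    /// every completion it is asked to stream.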
4053 struct ErrorInjector {
4054 inner: Arc<FakeLanguageModel>,
4055 error_type: TestError,
4056 }
4057
4058 impl ErrorInjector {
4059 fn new(error_type: TestError) -> Self {
4060 Self {
4061 inner: Arc::new(FakeLanguageModel::default()),
4062 error_type,
4063 }
4064 }
4065 }
4066
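    // Delegates capability queries to the inner fake model; only
    // `stream_completion` is overridden to inject the error.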
4067 impl LanguageModel for ErrorInjector {
4068 fn id(&self) -> LanguageModelId {
4069 self.inner.id()
4070 }
4071
4072 fn name(&self) -> LanguageModelName {
4073 self.inner.name()
4074 }
4075
4076 fn provider_id(&self) -> LanguageModelProviderId {
4077 self.inner.provider_id()
4078 }
4079
4080 fn provider_name(&self) -> LanguageModelProviderName {
4081 self.inner.provider_name()
4082 }
4083
4084 fn supports_tools(&self) -> bool {
4085 self.inner.supports_tools()
4086 }
4087
4088 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4089 self.inner.supports_tool_choice(choice)
4090 }
4091
4092 fn supports_images(&self) -> bool {
4093 self.inner.supports_images()
4094 }
4095
4096 fn telemetry_id(&self) -> String {
4097 self.inner.telemetry_id()
4098 }
4099
4100 fn max_token_count(&self) -> u64 {
4101 self.inner.max_token_count()
4102 }
4103
4104 fn count_tokens(
4105 &self,
4106 request: LanguageModelRequest,
4107 cx: &App,
4108 ) -> BoxFuture<'static, Result<u64>> {
4109 self.inner.count_tokens(request, cx)
4110 }
4111
4112 fn stream_completion(
4113 &self,
4114 _request: LanguageModelRequest,
4115 _cx: &AsyncApp,
4116 ) -> BoxFuture<
4117 'static,
4118 Result<
4119 BoxStream<
4120 'static,
4121 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4122 >,
4123 LanguageModelCompletionError,
4124 >,
4125 > {
4126 let error = match self.error_type {
4127 TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
4128 provider: self.provider_name(),
4129 retry_after: None,
4130 },
4131 TestError::InternalServerError => {
4132 LanguageModelCompletionError::ApiInternalServerError {
4133 provider: self.provider_name(),
4134 message: "I'm a teapot orbiting the sun".to_string(),
4135 }
4136 }
4137 };
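            // The request future itself resolves to Ok; the injected error is
            // yielded as the only item on the completion stream.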
4138 async move {
4139 let stream = futures::stream::once(async move { Err(error) });
4140 Ok(stream.boxed())
4141 }
4142 .boxed()
4143 }
4144
4145 fn as_fake(&self) -> &FakeLanguageModel {
4146 &self.inner
4147 }
4148 }
4149
4150 #[gpui::test]
4151 async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4152 let fs = init_test_settings(cx);
4153
4154 let project = create_test_project(&fs, cx, json!({})).await;
4155 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4156
4157 // Enable Burn Mode to allow retries
4158 thread.update(cx, |thread, _| {
4159 thread.set_completion_mode(CompletionMode::Burn);
4160 });
4161
4162 // Create model that returns overloaded error
4163 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4164
4165 // Insert a user message
4166 thread.update(cx, |thread, cx| {
4167 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4168 });
4169
4170 // Start completion
4171 thread.update(cx, |thread, cx| {
4172 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4173 });
4174
4175 cx.run_until_parked();
4176
4177 thread.read_with(cx, |thread, _| {
4178 assert!(thread.retry_state.is_some(), "Should have retry state");
4179 let retry_state = thread.retry_state.as_ref().unwrap();
4180 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4181 assert_eq!(
4182 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4183 "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
4184 );
4185 });
4186
4187 // Check that a retry message was added
4188 thread.read_with(cx, |thread, _| {
4189 let mut messages = thread.messages();
4190 assert!(
4191 messages.any(|msg| {
4192 msg.role == Role::System
4193 && msg.ui_only
4194 && msg.segments.iter().any(|seg| {
4195 if let MessageSegment::Text(text) = seg {
4196 text.contains("overloaded")
4197 && text
4198 .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4199 } else {
4200 false
4201 }
4202 })
4203 }),
4204 "Should have added a system retry message"
4205 );
4206 });
4207
4208 let retry_count = thread.update(cx, |thread, _| {
4209 thread
4210 .messages
4211 .iter()
4212 .filter(|m| {
4213 m.ui_only
4214 && m.segments.iter().any(|s| {
4215 if let MessageSegment::Text(text) = s {
4216 text.contains("Retrying") && text.contains("seconds")
4217 } else {
4218 false
4219 }
4220 })
4221 })
4222 .count()
4223 });
4224
4225 assert_eq!(retry_count, 1, "Should have one retry message");
4226 }
4227
4228 #[gpui::test]
4229 async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4230 let fs = init_test_settings(cx);
4231
4232 let project = create_test_project(&fs, cx, json!({})).await;
4233 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4234
4235 // Enable Burn Mode to allow retries
4236 thread.update(cx, |thread, _| {
4237 thread.set_completion_mode(CompletionMode::Burn);
4238 });
4239
4240 // Create model that returns internal server error
4241 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4242
4243 // Insert a user message
4244 thread.update(cx, |thread, cx| {
4245 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4246 });
4247
4248 // Start completion
4249 thread.update(cx, |thread, cx| {
4250 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4251 });
4252
4253 cx.run_until_parked();
4254
4255 // Check retry state on thread
4256 thread.read_with(cx, |thread, _| {
4257 assert!(thread.retry_state.is_some(), "Should have retry state");
4258 let retry_state = thread.retry_state.as_ref().unwrap();
4259 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4260 assert_eq!(
4261 retry_state.max_attempts, 3,
4262 "Should have correct max attempts"
4263 );
4264 });
4265
4266 // Check that a retry message was added with provider name
4267 thread.read_with(cx, |thread, _| {
4268 let mut messages = thread.messages();
4269 assert!(
4270 messages.any(|msg| {
4271 msg.role == Role::System
4272 && msg.ui_only
4273 && msg.segments.iter().any(|seg| {
4274 if let MessageSegment::Text(text) = seg {
4275 text.contains("internal")
4276 && text.contains("Fake")
4277 && text.contains("Retrying")
4278 && text.contains("attempt 1 of 3")
4279 && text.contains("seconds")
4280 } else {
4281 false
4282 }
4283 })
4284 }),
4285 "Should have added a system retry message with provider name"
4286 );
4287 });
4288
4289 // Count retry messages
4290 let retry_count = thread.update(cx, |thread, _| {
4291 thread
4292 .messages
4293 .iter()
4294 .filter(|m| {
4295 m.ui_only
4296 && m.segments.iter().any(|s| {
4297 if let MessageSegment::Text(text) = s {
4298 text.contains("Retrying") && text.contains("seconds")
4299 } else {
4300 false
4301 }
4302 })
4303 })
4304 .count()
4305 });
4306
4307 assert_eq!(retry_count, 1, "Should have one retry message");
4308 }
4309
4310 #[gpui::test]
4311 async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4312 let fs = init_test_settings(cx);
4313
4314 let project = create_test_project(&fs, cx, json!({})).await;
4315 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4316
4317 // Enable Burn Mode to allow retries
4318 thread.update(cx, |thread, _| {
4319 thread.set_completion_mode(CompletionMode::Burn);
4320 });
4321
4322 // Create model that returns internal server error
4323 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4324
4325 // Insert a user message
4326 thread.update(cx, |thread, cx| {
4327 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4328 });
4329
        // Track how many completion requests are started (initial + retries)
4332 let completion_count = Arc::new(Mutex::new(0));
4333 let completion_count_clone = completion_count.clone();
4334
4335 let _subscription = thread.update(cx, |_, cx| {
4336 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4337 if let ThreadEvent::NewRequest = event {
4338 *completion_count_clone.lock() += 1;
4339 }
4340 })
4341 });
4342
4343 // First attempt
4344 thread.update(cx, |thread, cx| {
4345 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4346 });
4347 cx.run_until_parked();
4348
4349 // Should have scheduled first retry - count retry messages
4350 let retry_count = thread.update(cx, |thread, _| {
4351 thread
4352 .messages
4353 .iter()
4354 .filter(|m| {
4355 m.ui_only
4356 && m.segments.iter().any(|s| {
4357 if let MessageSegment::Text(text) = s {
4358 text.contains("Retrying") && text.contains("seconds")
4359 } else {
4360 false
4361 }
4362 })
4363 })
4364 .count()
4365 });
4366 assert_eq!(retry_count, 1, "Should have scheduled first retry");
4367
4368 // Check retry state
4369 thread.read_with(cx, |thread, _| {
4370 assert!(thread.retry_state.is_some(), "Should have retry state");
4371 let retry_state = thread.retry_state.as_ref().unwrap();
4372 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4373 assert_eq!(
4374 retry_state.max_attempts, 3,
4375 "Internal server errors should retry up to 3 times"
4376 );
4377 });
4378
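        // Each clock advance lets the next scheduled retry fire; every attempt
        // hits the injector and fails again.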
4379 // Advance clock for first retry
4380 cx.executor().advance_clock(BASE_RETRY_DELAY);
4381 cx.run_until_parked();
4382
4383 // Advance clock for second retry
4384 cx.executor().advance_clock(BASE_RETRY_DELAY);
4385 cx.run_until_parked();
4386
4387 // Advance clock for third retry
4388 cx.executor().advance_clock(BASE_RETRY_DELAY);
4389 cx.run_until_parked();
4390
4391 // Should have completed all retries - count retry messages
4392 let retry_count = thread.update(cx, |thread, _| {
4393 thread
4394 .messages
4395 .iter()
4396 .filter(|m| {
4397 m.ui_only
4398 && m.segments.iter().any(|s| {
4399 if let MessageSegment::Text(text) = s {
4400 text.contains("Retrying") && text.contains("seconds")
4401 } else {
4402 false
4403 }
4404 })
4405 })
4406 .count()
4407 });
4408 assert_eq!(
4409 retry_count, 3,
4410 "Should have 3 retries for internal server errors"
4411 );
4412
4413 // For internal server errors, we retry 3 times and then give up
4414 // Check that retry_state is cleared after all retries
4415 thread.read_with(cx, |thread, _| {
4416 assert!(
4417 thread.retry_state.is_none(),
4418 "Retry state should be cleared after all retries"
4419 );
4420 });
4421
4422 // Verify total attempts (1 initial + 3 retries)
4423 assert_eq!(
4424 *completion_count.lock(),
4425 4,
4426 "Should have attempted once plus 3 retries"
4427 );
4428 }
4429
4430 #[gpui::test]
4431 async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4432 let fs = init_test_settings(cx);
4433
4434 let project = create_test_project(&fs, cx, json!({})).await;
4435 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4436
4437 // Enable Burn Mode to allow retries
4438 thread.update(cx, |thread, _| {
4439 thread.set_completion_mode(CompletionMode::Burn);
4440 });
4441
4442 // Create model that returns overloaded error
4443 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4444
4445 // Insert a user message
4446 thread.update(cx, |thread, cx| {
4447 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4448 });
4449
4450 // Track events
4451 let stopped_with_error = Arc::new(Mutex::new(false));
4452 let stopped_with_error_clone = stopped_with_error.clone();
4453
4454 let _subscription = thread.update(cx, |_, cx| {
4455 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4456 if let ThreadEvent::Stopped(Err(_)) = event {
4457 *stopped_with_error_clone.lock() = true;
4458 }
4459 })
4460 });
4461
4462 // Start initial completion
4463 thread.update(cx, |thread, cx| {
4464 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4465 });
4466 cx.run_until_parked();
4467
4468 // Advance through all retries
4469 for _ in 0..MAX_RETRY_ATTEMPTS {
4470 cx.executor().advance_clock(BASE_RETRY_DELAY);
4471 cx.run_until_parked();
4472 }
4473
4474 let retry_count = thread.update(cx, |thread, _| {
4475 thread
4476 .messages
4477 .iter()
4478 .filter(|m| {
4479 m.ui_only
4480 && m.segments.iter().any(|s| {
4481 if let MessageSegment::Text(text) = s {
4482 text.contains("Retrying") && text.contains("seconds")
4483 } else {
4484 false
4485 }
4486 })
4487 })
4488 .count()
4489 });
4490
4491 // After max retries, should emit Stopped(Err(...)) event
4492 assert_eq!(
4493 retry_count, MAX_RETRY_ATTEMPTS as usize,
4494 "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
4495 );
4496 assert!(
4497 *stopped_with_error.lock(),
4498 "Should emit Stopped(Err(...)) event after max retries exceeded"
4499 );
4500
4501 // Retry state should be cleared
4502 thread.read_with(cx, |thread, _| {
4503 assert!(
4504 thread.retry_state.is_none(),
4505 "Retry state should be cleared after max retries"
4506 );
4507
4508 // Verify we have the expected number of retry messages
4509 let retry_messages = thread
4510 .messages
4511 .iter()
4512 .filter(|msg| msg.ui_only && msg.role == Role::System)
4513 .count();
4514 assert_eq!(
4515 retry_messages, MAX_RETRY_ATTEMPTS as usize,
4516 "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
4517 );
4518 });
4519 }
4520
4521 #[gpui::test]
4522 async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
4523 let fs = init_test_settings(cx);
4524
4525 let project = create_test_project(&fs, cx, json!({})).await;
4526 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4527
4528 // Enable Burn Mode to allow retries
4529 thread.update(cx, |thread, _| {
4530 thread.set_completion_mode(CompletionMode::Burn);
4531 });
4532
4533 // We'll use a wrapper to switch behavior after first failure
4534 struct RetryTestModel {
4535 inner: Arc<FakeLanguageModel>,
4536 failed_once: Arc<Mutex<bool>>,
4537 }
4538
4539 impl LanguageModel for RetryTestModel {
4540 fn id(&self) -> LanguageModelId {
4541 self.inner.id()
4542 }
4543
4544 fn name(&self) -> LanguageModelName {
4545 self.inner.name()
4546 }
4547
4548 fn provider_id(&self) -> LanguageModelProviderId {
4549 self.inner.provider_id()
4550 }
4551
4552 fn provider_name(&self) -> LanguageModelProviderName {
4553 self.inner.provider_name()
4554 }
4555
4556 fn supports_tools(&self) -> bool {
4557 self.inner.supports_tools()
4558 }
4559
4560 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4561 self.inner.supports_tool_choice(choice)
4562 }
4563
4564 fn supports_images(&self) -> bool {
4565 self.inner.supports_images()
4566 }
4567
4568 fn telemetry_id(&self) -> String {
4569 self.inner.telemetry_id()
4570 }
4571
4572 fn max_token_count(&self) -> u64 {
4573 self.inner.max_token_count()
4574 }
4575
4576 fn count_tokens(
4577 &self,
4578 request: LanguageModelRequest,
4579 cx: &App,
4580 ) -> BoxFuture<'static, Result<u64>> {
4581 self.inner.count_tokens(request, cx)
4582 }
4583
4584 fn stream_completion(
4585 &self,
4586 request: LanguageModelRequest,
4587 cx: &AsyncApp,
4588 ) -> BoxFuture<
4589 'static,
4590 Result<
4591 BoxStream<
4592 'static,
4593 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4594 >,
4595 LanguageModelCompletionError,
4596 >,
4597 > {
4598 if !*self.failed_once.lock() {
4599 *self.failed_once.lock() = true;
4600 let provider = self.provider_name();
4601 // Return error on first attempt
4602 let stream = futures::stream::once(async move {
4603 Err(LanguageModelCompletionError::ServerOverloaded {
4604 provider,
4605 retry_after: None,
4606 })
4607 });
4608 async move { Ok(stream.boxed()) }.boxed()
4609 } else {
4610 // Succeed on retry
4611 self.inner.stream_completion(request, cx)
4612 }
4613 }
4614
4615 fn as_fake(&self) -> &FakeLanguageModel {
4616 &self.inner
4617 }
4618 }
4619
4620 let model = Arc::new(RetryTestModel {
4621 inner: Arc::new(FakeLanguageModel::default()),
4622 failed_once: Arc::new(Mutex::new(false)),
4623 });
4624
4625 // Insert a user message
4626 thread.update(cx, |thread, cx| {
4627 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4628 });
4629
        // Track when the retried completion streams successfully
4632 let retry_completed = Arc::new(Mutex::new(false));
4633 let retry_completed_clone = retry_completed.clone();
4634
4635 let _subscription = thread.update(cx, |_, cx| {
4636 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4637 if let ThreadEvent::StreamedCompletion = event {
4638 *retry_completed_clone.lock() = true;
4639 }
4640 })
4641 });
4642
4643 // Start completion
4644 thread.update(cx, |thread, cx| {
4645 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4646 });
4647 cx.run_until_parked();
4648
4649 // Get the retry message ID
4650 let retry_message_id = thread.read_with(cx, |thread, _| {
4651 thread
4652 .messages()
4653 .find(|msg| msg.role == Role::System && msg.ui_only)
4654 .map(|msg| msg.id)
4655 .expect("Should have a retry message")
4656 });
4657
4658 // Wait for retry
4659 cx.executor().advance_clock(BASE_RETRY_DELAY);
4660 cx.run_until_parked();
4661
4662 // Stream some successful content
4663 let fake_model = model.as_fake();
4664 // After the retry, there should be a new pending completion
4665 let pending = fake_model.pending_completions();
4666 assert!(
4667 !pending.is_empty(),
4668 "Should have a pending completion after retry"
4669 );
4670 fake_model.send_completion_stream_text_chunk(&pending[0], "Success!");
4671 fake_model.end_completion_stream(&pending[0]);
4672 cx.run_until_parked();
4673
4674 // Check that the retry completed successfully
4675 assert!(
4676 *retry_completed.lock(),
4677 "Retry should have completed successfully"
4678 );
4679
4680 // Retry message should still exist but be marked as ui_only
4681 thread.read_with(cx, |thread, _| {
4682 let retry_msg = thread
4683 .message(retry_message_id)
4684 .expect("Retry message should still exist");
4685 assert!(retry_msg.ui_only, "Retry message should be ui_only");
4686 assert_eq!(
4687 retry_msg.role,
4688 Role::System,
4689 "Retry message should have System role"
4690 );
4691 });
4692 }
4693
4694 #[gpui::test]
4695 async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4696 let fs = init_test_settings(cx);
4697
4698 let project = create_test_project(&fs, cx, json!({})).await;
4699 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4700
4701 // Enable Burn Mode to allow retries
4702 thread.update(cx, |thread, _| {
4703 thread.set_completion_mode(CompletionMode::Burn);
4704 });
4705
4706 // Create a model that fails once then succeeds
4707 struct FailOnceModel {
4708 inner: Arc<FakeLanguageModel>,
4709 failed_once: Arc<Mutex<bool>>,
4710 }
4711
4712 impl LanguageModel for FailOnceModel {
4713 fn id(&self) -> LanguageModelId {
4714 self.inner.id()
4715 }
4716
4717 fn name(&self) -> LanguageModelName {
4718 self.inner.name()
4719 }
4720
4721 fn provider_id(&self) -> LanguageModelProviderId {
4722 self.inner.provider_id()
4723 }
4724
4725 fn provider_name(&self) -> LanguageModelProviderName {
4726 self.inner.provider_name()
4727 }
4728
4729 fn supports_tools(&self) -> bool {
4730 self.inner.supports_tools()
4731 }
4732
4733 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4734 self.inner.supports_tool_choice(choice)
4735 }
4736
4737 fn supports_images(&self) -> bool {
4738 self.inner.supports_images()
4739 }
4740
4741 fn telemetry_id(&self) -> String {
4742 self.inner.telemetry_id()
4743 }
4744
4745 fn max_token_count(&self) -> u64 {
4746 self.inner.max_token_count()
4747 }
4748
4749 fn count_tokens(
4750 &self,
4751 request: LanguageModelRequest,
4752 cx: &App,
4753 ) -> BoxFuture<'static, Result<u64>> {
4754 self.inner.count_tokens(request, cx)
4755 }
4756
4757 fn stream_completion(
4758 &self,
4759 request: LanguageModelRequest,
4760 cx: &AsyncApp,
4761 ) -> BoxFuture<
4762 'static,
4763 Result<
4764 BoxStream<
4765 'static,
4766 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4767 >,
4768 LanguageModelCompletionError,
4769 >,
4770 > {
4771 if !*self.failed_once.lock() {
4772 *self.failed_once.lock() = true;
4773 let provider = self.provider_name();
4774 // Return error on first attempt
4775 let stream = futures::stream::once(async move {
4776 Err(LanguageModelCompletionError::ServerOverloaded {
4777 provider,
4778 retry_after: None,
4779 })
4780 });
4781 async move { Ok(stream.boxed()) }.boxed()
4782 } else {
4783 // Succeed on retry
4784 self.inner.stream_completion(request, cx)
4785 }
4786 }
4787 }
4788
4789 let fail_once_model = Arc::new(FailOnceModel {
4790 inner: Arc::new(FakeLanguageModel::default()),
4791 failed_once: Arc::new(Mutex::new(false)),
4792 });
4793
4794 // Insert a user message
4795 thread.update(cx, |thread, cx| {
4796 thread.insert_user_message(
4797 "Test message",
4798 ContextLoadResult::default(),
4799 None,
4800 vec![],
4801 cx,
4802 );
4803 });
4804
4805 // Start completion with fail-once model
4806 thread.update(cx, |thread, cx| {
4807 thread.send_to_model(
4808 fail_once_model.clone(),
4809 CompletionIntent::UserPrompt,
4810 None,
4811 cx,
4812 );
4813 });
4814
4815 cx.run_until_parked();
4816
4817 // Verify retry state exists after first failure
4818 thread.read_with(cx, |thread, _| {
4819 assert!(
4820 thread.retry_state.is_some(),
4821 "Should have retry state after failure"
4822 );
4823 });
4824
4825 // Wait for retry delay
4826 cx.executor().advance_clock(BASE_RETRY_DELAY);
4827 cx.run_until_parked();
4828
        // The retry goes through FailOnceModel again, which now succeeds;
        // drive the inner FakeLanguageModel to complete the stream.
4831 let inner_fake = fail_once_model.inner.clone();
4832
4833 // Wait a bit for the retry to start
4834 cx.run_until_parked();
4835
4836 // Check for pending completions and complete them
4837 if let Some(pending) = inner_fake.pending_completions().first() {
4838 inner_fake.send_completion_stream_text_chunk(pending, "Success!");
4839 inner_fake.end_completion_stream(pending);
4840 }
4841 cx.run_until_parked();
4842
4843 thread.read_with(cx, |thread, _| {
4844 assert!(
4845 thread.retry_state.is_none(),
4846 "Retry state should be cleared after successful completion"
4847 );
4848
4849 let has_assistant_message = thread
4850 .messages
4851 .iter()
4852 .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4853 assert!(
4854 has_assistant_message,
4855 "Should have an assistant message after successful retry"
4856 );
4857 });
4858 }
4859
4860 #[gpui::test]
4861 async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4862 let fs = init_test_settings(cx);
4863
4864 let project = create_test_project(&fs, cx, json!({})).await;
4865 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4866
4867 // Enable Burn Mode to allow retries
4868 thread.update(cx, |thread, _| {
4869 thread.set_completion_mode(CompletionMode::Burn);
4870 });
4871
4872 // Create a model that returns rate limit error with retry_after
4873 struct RateLimitModel {
4874 inner: Arc<FakeLanguageModel>,
4875 }
4876
4877 impl LanguageModel for RateLimitModel {
4878 fn id(&self) -> LanguageModelId {
4879 self.inner.id()
4880 }
4881
4882 fn name(&self) -> LanguageModelName {
4883 self.inner.name()
4884 }
4885
4886 fn provider_id(&self) -> LanguageModelProviderId {
4887 self.inner.provider_id()
4888 }
4889
4890 fn provider_name(&self) -> LanguageModelProviderName {
4891 self.inner.provider_name()
4892 }
4893
4894 fn supports_tools(&self) -> bool {
4895 self.inner.supports_tools()
4896 }
4897
4898 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4899 self.inner.supports_tool_choice(choice)
4900 }
4901
4902 fn supports_images(&self) -> bool {
4903 self.inner.supports_images()
4904 }
4905
4906 fn telemetry_id(&self) -> String {
4907 self.inner.telemetry_id()
4908 }
4909
4910 fn max_token_count(&self) -> u64 {
4911 self.inner.max_token_count()
4912 }
4913
4914 fn count_tokens(
4915 &self,
4916 request: LanguageModelRequest,
4917 cx: &App,
4918 ) -> BoxFuture<'static, Result<u64>> {
4919 self.inner.count_tokens(request, cx)
4920 }
4921
4922 fn stream_completion(
4923 &self,
4924 _request: LanguageModelRequest,
4925 _cx: &AsyncApp,
4926 ) -> BoxFuture<
4927 'static,
4928 Result<
4929 BoxStream<
4930 'static,
4931 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4932 >,
4933 LanguageModelCompletionError,
4934 >,
4935 > {
4936 let provider = self.provider_name();
4937 async move {
4938 let stream = futures::stream::once(async move {
4939 Err(LanguageModelCompletionError::RateLimitExceeded {
4940 provider,
4941 retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4942 })
4943 });
4944 Ok(stream.boxed())
4945 }
4946 .boxed()
4947 }
4948
4949 fn as_fake(&self) -> &FakeLanguageModel {
4950 &self.inner
4951 }
4952 }
4953
4954 let model = Arc::new(RateLimitModel {
4955 inner: Arc::new(FakeLanguageModel::default()),
4956 });
4957
4958 // Insert a user message
4959 thread.update(cx, |thread, cx| {
4960 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4961 });
4962
4963 // Start completion
4964 thread.update(cx, |thread, cx| {
4965 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4966 });
4967
4968 cx.run_until_parked();
4969
4970 let retry_count = thread.update(cx, |thread, _| {
4971 thread
4972 .messages
4973 .iter()
4974 .filter(|m| {
4975 m.ui_only
4976 && m.segments.iter().any(|s| {
4977 if let MessageSegment::Text(text) = s {
4978 text.contains("rate limit exceeded")
4979 } else {
4980 false
4981 }
4982 })
4983 })
4984 .count()
4985 });
4986 assert_eq!(retry_count, 1, "Should have scheduled one retry");
4987
4988 thread.read_with(cx, |thread, _| {
4989 assert!(
4990 thread.retry_state.is_some(),
4991 "Rate limit errors should set retry_state"
4992 );
4993 if let Some(retry_state) = &thread.retry_state {
4994 assert_eq!(
4995 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4996 "Rate limit errors should use MAX_RETRY_ATTEMPTS"
4997 );
4998 }
4999 });
5000
5001 // Verify we have one retry message
5002 thread.read_with(cx, |thread, _| {
5003 let retry_messages = thread
5004 .messages
5005 .iter()
5006 .filter(|msg| {
5007 msg.ui_only
5008 && msg.segments.iter().any(|seg| {
5009 if let MessageSegment::Text(text) = seg {
5010 text.contains("rate limit exceeded")
5011 } else {
5012 false
5013 }
5014 })
5015 })
5016 .count();
5017 assert_eq!(
5018 retry_messages, 1,
5019 "Should have one rate limit retry message"
5020 );
5021 });
5022
        // Check that the retry message includes the attempt count
5024 thread.read_with(cx, |thread, _| {
5025 let retry_message = thread
5026 .messages
5027 .iter()
5028 .find(|msg| msg.role == Role::System && msg.ui_only)
5029 .expect("Should have a retry message");
5030
            // Rate-limit retries go through retry_state, so the message should
            // include the attempt count
5032 if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5033 assert!(
5034 text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)),
5035 "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
5036 );
5037 assert!(
5038 text.contains("Retrying"),
5039 "Rate limit retry message should contain retry text"
5040 );
5041 }
5042 });
5043 }
5044
5045 #[gpui::test]
5046 async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5047 let fs = init_test_settings(cx);
5048
5049 let project = create_test_project(&fs, cx, json!({})).await;
5050 let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5051
5052 // Insert a regular user message
5053 thread.update(cx, |thread, cx| {
5054 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5055 });
5056
5057 // Insert a UI-only message (like our retry notifications)
5058 thread.update(cx, |thread, cx| {
5059 let id = thread.next_message_id.post_inc();
5060 thread.messages.push(Message {
5061 id,
5062 role: Role::System,
5063 segments: vec![MessageSegment::Text(
5064 "This is a UI-only message that should not be sent to the model".to_string(),
5065 )],
5066 loaded_context: LoadedContext::default(),
5067 creases: Vec::new(),
5068 is_hidden: true,
5069 ui_only: true,
5070 });
5071 cx.emit(ThreadEvent::MessageAdded(id));
5072 });
5073
5074 // Insert another regular message
5075 thread.update(cx, |thread, cx| {
5076 thread.insert_user_message(
5077 "How are you?",
5078 ContextLoadResult::default(),
5079 None,
5080 vec![],
5081 cx,
5082 );
5083 });
5084
5085 // Generate the completion request
5086 let request = thread.update(cx, |thread, cx| {
5087 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5088 });
5089
5090 // Verify that the request only contains non-UI-only messages
5091 // Should have system prompt + 2 user messages, but not the UI-only message
5092 let user_messages: Vec<_> = request
5093 .messages
5094 .iter()
5095 .filter(|msg| msg.role == Role::User)
5096 .collect();
5097 assert_eq!(
5098 user_messages.len(),
5099 2,
5100 "Should have exactly 2 user messages"
5101 );
5102
5103 // Verify the UI-only content is not present anywhere in the request
5104 let request_text = request
5105 .messages
5106 .iter()
5107 .flat_map(|msg| &msg.content)
5108 .filter_map(|content| match content {
5109 MessageContent::Text(text) => Some(text.as_str()),
5110 _ => None,
5111 })
5112 .collect::<String>();
5113
5114 assert!(
5115 !request_text.contains("UI-only message"),
5116 "UI-only message content should not be in the request"
5117 );
5118
5119 // Verify the thread still has all 3 messages (including UI-only)
5120 thread.read_with(cx, |thread, _| {
5121 assert_eq!(
5122 thread.messages().count(),
5123 3,
5124 "Thread should have 3 messages"
5125 );
5126 assert_eq!(
5127 thread.messages().filter(|m| m.ui_only).count(),
5128 1,
5129 "Thread should have 1 UI-only message"
5130 );
5131 });
5132
5133 // Verify that UI-only messages are not serialized
5134 let serialized = thread
5135 .update(cx, |thread, cx| thread.serialize(cx))
5136 .await
5137 .unwrap();
5138 assert_eq!(
5139 serialized.messages.len(),
5140 2,
5141 "Serialized thread should only have 2 messages (no UI-only)"
5142 );
5143 }
5144
5145 #[gpui::test]
5146 async fn test_no_retry_without_burn_mode(cx: &mut TestAppContext) {
5147 let fs = init_test_settings(cx);
5148
5149 let project = create_test_project(&fs, cx, json!({})).await;
5150 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5151
5152 // Ensure we're in Normal mode (not Burn mode)
5153 thread.update(cx, |thread, _| {
5154 thread.set_completion_mode(CompletionMode::Normal);
5155 });
5156
5157 // Track error events
5158 let error_events = Arc::new(Mutex::new(Vec::new()));
5159 let error_events_clone = error_events.clone();
5160
5161 let _subscription = thread.update(cx, |_, cx| {
5162 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
5163 if let ThreadEvent::ShowError(error) = event {
5164 error_events_clone.lock().push(error.clone());
5165 }
5166 })
5167 });
5168
5169 // Create model that returns overloaded error
5170 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5171
5172 // Insert a user message
5173 thread.update(cx, |thread, cx| {
5174 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5175 });
5176
5177 // Start completion
5178 thread.update(cx, |thread, cx| {
5179 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5180 });
5181
5182 cx.run_until_parked();
5183
5184 // Verify no retry state was created
5185 thread.read_with(cx, |thread, _| {
5186 assert!(
5187 thread.retry_state.is_none(),
5188 "Should not have retry state in Normal mode"
5189 );
5190 });
5191
5192 // Check that a retryable error was reported
5193 let errors = error_events.lock();
5194 assert!(!errors.is_empty(), "Should have received an error event");
5195
5196 if let ThreadError::RetryableError {
5197 message: _,
5198 can_enable_burn_mode,
5199 } = &errors[0]
5200 {
5201 assert!(
5202 *can_enable_burn_mode,
5203 "Error should indicate burn mode can be enabled"
5204 );
5205 } else {
5206 panic!("Expected RetryableError, got {:?}", errors[0]);
5207 }
5208
5209 // Verify the thread is no longer generating
5210 thread.read_with(cx, |thread, _| {
5211 assert!(
5212 !thread.is_generating(),
5213 "Should not be generating after error without retry"
5214 );
5215 });
5216 }
5217
5218 #[gpui::test]
5219 async fn test_retry_canceled_on_stop(cx: &mut TestAppContext) {
5220 let fs = init_test_settings(cx);
5221
5222 let project = create_test_project(&fs, cx, json!({})).await;
5223 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5224
5225 // Enable Burn Mode to allow retries
5226 thread.update(cx, |thread, _| {
5227 thread.set_completion_mode(CompletionMode::Burn);
5228 });
5229
5230 // Create model that returns overloaded error
5231 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5232
5233 // Insert a user message
5234 thread.update(cx, |thread, cx| {
5235 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5236 });
5237
5238 // Start completion
5239 thread.update(cx, |thread, cx| {
5240 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5241 });
5242
5243 cx.run_until_parked();
5244
5245 // Verify retry was scheduled by checking for retry message
5246 let has_retry_message = thread.read_with(cx, |thread, _| {
5247 thread.messages.iter().any(|m| {
5248 m.ui_only
5249 && m.segments.iter().any(|s| {
5250 if let MessageSegment::Text(text) = s {
5251 text.contains("Retrying") && text.contains("seconds")
5252 } else {
5253 false
5254 }
5255 })
5256 })
5257 });
5258 assert!(has_retry_message, "Should have scheduled a retry");
5259
5260 // Cancel the completion before the retry happens
5261 thread.update(cx, |thread, cx| {
5262 thread.cancel_last_completion(None, cx);
5263 });
5264
5265 cx.run_until_parked();
5266
5267 // The retry should not have happened - no pending completions
5268 let fake_model = model.as_fake();
5269 assert_eq!(
5270 fake_model.pending_completions().len(),
5271 0,
5272 "Should have no pending completions after cancellation"
5273 );
5274
5275 // Verify the retry was canceled by checking retry state
5276 thread.read_with(cx, |thread, _| {
5277 if let Some(retry_state) = &thread.retry_state {
5278 panic!(
5279 "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5280 retry_state.attempt, retry_state.max_attempts, retry_state.intent
5281 );
5282 }
5283 });
5284 }
5285
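    /// Drives the thread through a summarization whose stream ends without
    /// producing any text, leaving the summary in the `Error` state.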
5286 fn test_summarize_error(
5287 model: &Arc<dyn LanguageModel>,
5288 thread: &Entity<Thread>,
5289 cx: &mut TestAppContext,
5290 ) {
5291 thread.update(cx, |thread, cx| {
5292 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5293 thread.send_to_model(
5294 model.clone(),
                CompletionIntent::UserPrompt,
5296 None,
5297 cx,
5298 );
5299 });
5300
5301 let fake_model = model.as_fake();
5302 simulate_successful_response(fake_model, cx);
5303
5304 thread.read_with(cx, |thread, _| {
5305 assert!(matches!(thread.summary(), ThreadSummary::Generating));
5306 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5307 });
5308
        // End the summary stream without any text to force a summarization error
5310 cx.run_until_parked();
5311 fake_model.end_last_completion_stream();
5312 cx.run_until_parked();
5313
5314 // State is set to Error and default message
5315 thread.read_with(cx, |thread, _| {
5316 assert!(matches!(thread.summary(), ThreadSummary::Error));
5317 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5318 });
5319 }
5320
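    /// Streams a canned assistant reply on the most recent completion and
    /// waits for the executor to settle.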
5321 fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
5322 cx.run_until_parked();
5323 fake_model.send_last_completion_stream_text_chunk("Assistant response");
5324 fake_model.end_last_completion_stream();
5325 cx.run_until_parked();
5326 }
5327
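    /// Registers the global settings, theme, tool, and language-model state
    /// these tests rely on, returning the fake filesystem.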
5328 fn init_test_settings(cx: &mut TestAppContext) -> Arc<dyn Fs> {
5329 let fs = FakeFs::new(cx.executor());
5330 cx.update(|cx| {
5331 let settings_store = SettingsStore::test(cx);
5332 cx.set_global(settings_store);
5333 language::init(cx);
5334 Project::init_settings(cx);
5335 AgentSettings::register(cx);
5336 prompt_store::init(cx);
5337 thread_store::init(fs.clone(), cx);
5338 workspace::init_settings(cx);
5339 language_model::init_settings(cx);
5340 ThemeSettings::register(cx);
5341 ToolRegistry::default_global(cx);
5342 assistant_tool::init(cx);
5343
5344 let http_client = Arc::new(http_client::HttpClientWithUrl::new(
5345 http_client::FakeHttpClient::with_200_response(),
5346 "http://localhost".to_string(),
5347 None,
5348 ));
5349 assistant_tools::init(http_client, cx);
5350 });
5351 fs
5352 }
5353
5354 // Helper to create a test project with test files
5355 async fn create_test_project(
5356 fs: &Arc<dyn Fs>,
5357 cx: &mut TestAppContext,
5358 files: serde_json::Value,
5359 ) -> Entity<Project> {
5360 fs.as_fake().insert_tree(path!("/test"), files).await;
5361 Project::test(fs.clone(), [path!("/test").as_ref()], cx).await
5362 }
5363
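    /// Creates a workspace, thread store, thread, and context store for the
    /// given project, registering a fake model as both the default and the
    /// thread-summary model.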
5364 async fn setup_test_environment(
5365 cx: &mut TestAppContext,
5366 project: Entity<Project>,
5367 ) -> (
5368 Entity<Workspace>,
5369 Entity<ThreadStore>,
5370 Entity<Thread>,
5371 Entity<ContextStore>,
5372 Arc<dyn LanguageModel>,
5373 ) {
5374 let (workspace, cx) =
5375 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5376
5377 let thread_store = cx
5378 .update(|_, cx| {
5379 ThreadStore::load(
5380 project.clone(),
5381 cx.new(|_| ToolWorkingSet::default()),
5382 None,
5383 Arc::new(PromptBuilder::new(None).unwrap()),
5384 cx,
5385 )
5386 })
5387 .await
5388 .unwrap();
5389
5390 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5391 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5392
5393 let provider = Arc::new(FakeLanguageModelProvider::default());
5394 let model = provider.test_model();
5395 let model: Arc<dyn LanguageModel> = Arc::new(model);
5396
5397 cx.update(|_, cx| {
5398 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5399 registry.set_default_model(
5400 Some(ConfiguredModel {
5401 provider: provider.clone(),
5402 model: model.clone(),
5403 }),
5404 cx,
5405 );
5406 registry.set_thread_summary_model(
5407 Some(ConfiguredModel {
5408 provider,
5409 model: model.clone(),
5410 }),
5411 cx,
5412 );
5413 })
5414 });
5415
5416 (workspace, thread_store, thread, context_store, model)
5417 }
5418
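    /// Opens `path` in the project and adds the resulting buffer to the
    /// context store.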
5419 async fn add_file_to_context(
5420 project: &Entity<Project>,
5421 context_store: &Entity<ContextStore>,
5422 path: &str,
5423 cx: &mut TestAppContext,
5424 ) -> Result<Entity<language::Buffer>> {
5425 let buffer_path = project
5426 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5427 .unwrap();
5428
5429 let buffer = project
5430 .update(cx, |project, cx| {
5431 project.open_buffer(buffer_path.clone(), cx)
5432 })
5433 .await
5434 .unwrap();
5435
5436 context_store.update(cx, |context_store, cx| {
5437 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5438 });
5439
5440 Ok(buffer)
5441 }
5442}