use crate::{
    agent_profile::AgentProfile,
    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
    thread_store::{
        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
        ThreadStore,
    },
    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
};
use action_log::ActionLog;
use agent_settings::{
    AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
    SUMMARIZE_THREAD_PROMPT,
};
use anyhow::{Result, anyhow};
use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, future::Shared};
use git::repository::DiffType;
use gpui::{
    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
    WeakEntity, Window,
};
use http_client::StatusCode;
use language_model::{
    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
    TokenUsage,
};
use postage::stream::Stream as _;
use project::{
    Project,
    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
};
use prompt_store::{ModelContext, PromptBuilder};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
    io::Write,
    ops::Range,
    sync::Arc,
    time::{Duration, Instant},
};
use thiserror::Error;
use util::{ResultExt as _, post_inc};
use uuid::Uuid;

const MAX_RETRY_ATTEMPTS: u8 = 4;
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);

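/// How a failed completion request should be retried.
///
/// A minimal sketch of the delay a given attempt could derive from these
/// variants (illustrative only; the actual scheduling lives in the retry
/// handling further down this file, and the doubling formula is an assumption
/// based on the variant name and `BASE_RETRY_DELAY`):
///
/// ```ignore
/// fn delay_for(strategy: &RetryStrategy, attempt: u8) -> Duration {
///     match strategy {
///         // e.g. 5s, 10s, 20s, ... for attempts 1, 2, 3, ...
///         RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
///             *initial_delay * 2u32.pow(u32::from(attempt - 1))
///         }
///         RetryStrategy::Fixed { delay, .. } => *delay,
///     }
/// }
/// ```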
#[derive(Debug, Clone)]
enum RetryStrategy {
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}

#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}

/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub usize);

impl MessageId {
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    pub fn as_usize(&self) -> usize {
        self.0
    }
}

/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    pub segments: Vec<MessageSegment>,
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
    pub is_hidden: bool,
    pub ui_only: bool,
}

impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    /// The model sometimes runs tools without producing any text, or produces just a marker ([`USING_TOOL_MARKER`]).
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }

    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        if let Some(MessageSegment::Thinking {
            text: segment,
            signature: current_signature,
        }) = self.segments.last_mut()
        {
            if let Some(signature) = signature {
                *current_signature = Some(signature);
            }
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Thinking {
                text: text.to_string(),
                signature,
            });
        }
    }

    pub fn push_redacted_thinking(&mut self, data: String) {
        self.segments.push(MessageSegment::RedactedThinking(data));
    }

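    /// Appends `text` to the trailing [`MessageSegment::Text`], starting a new
    /// text segment only when the last segment is not a text segment, so
    /// consecutive streamed chunks coalesce. A sketch of the merging behavior,
    /// assuming `message` is a `Message` ending in no text segment:
    ///
    /// ```ignore
    /// message.push_text("Hel");
    /// message.push_text("lo");
    /// assert_eq!(message.segments.last().and_then(|s| s.text()), Some("Hello"));
    /// ```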
    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

    pub fn to_string(&self) -> String {
        let mut result = String::new();

        if !self.loaded_context.text.is_empty() {
            result.push_str(&self.loaded_context.text);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    result.push_str("<think>\n");
                    result.push_str(text);
                    result.push_str("\n</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        result
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(String),
}

impl MessageSegment {
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }

    pub fn text(&self) -> Option<&str> {
        match self {
            MessageSegment::Text(text) => Some(text),
            _ => None,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}

impl LastRestoreCheckpoint {
    pub fn message_id(&self) -> MessageId {
        match self {
            LastRestoreCheckpoint::Pending { message_id } => *message_id,
            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
        }
    }
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}

impl DetailedSummaryState {
    fn text(&self) -> Option<SharedString> {
        if let Self::Generated { text, .. } = self {
            Some(text.clone())
        } else {
            None
        }
    }
}

#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: u64,
    pub max: u64,
}

impl TotalTokenUsage {
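    /// Classifies the current usage relative to the context window.
    ///
    /// A minimal usage sketch, assuming the default 0.8 warning threshold:
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 850, max: 1000 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Warning);
    /// ```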
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or("0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}

#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}

/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
    last_error_context: Option<(Arc<dyn LanguageModel>, CompletionIntent)>,
}

#[derive(Clone, Debug)]
struct RetryState {
    attempt: u8,
    max_attempts: u8,
    intent: CompletionIntent,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    Pending,
    Generating,
    Ready(SharedString),
    Error,
}

impl ThreadSummary {
    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");

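    /// Returns the generated summary, or [`Self::DEFAULT`] when none is ready yet.
    ///
    /// ```ignore
    /// assert_eq!(ThreadSummary::Pending.or_default(), ThreadSummary::DEFAULT);
    /// ```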
    pub fn or_default(&self) -> SharedString {
        self.unwrap_or(Self::DEFAULT)
    }

    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
        self.ready().unwrap_or_else(|| message.into())
    }

    pub fn ready(&self) -> Option<SharedString> {
        match self {
            ThreadSummary::Ready(summary) => Some(summary.clone()),
            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}

impl Thread {
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_error_context: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model: configured_model.clone(),
            profile: AgentProfile::new(profile_id, tools),
        }
    }

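    /// Reconstructs a thread from its serialized form, restoring messages, tool
    /// state, and the previously selected model (falling back to the default
    /// model when the serialized one is unavailable).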
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_error_context: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }

    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
            + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }

    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        self.finalize_checkpoint(pending_checkpoint, cx);
    }

    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                this.update(cx, |this, cx| {
                    this.pending_checkpoint = if equal {
                        Some(pending_checkpoint)
                    } else {
                        this.insert_checkpoint(pending_checkpoint, cx);
                        Some(ThreadCheckpoint {
                            message_id: this.next_message_id,
                            git_checkpoint: final_checkpoint,
                        })
                    }
                })?;

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }

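    /// Returns whether the message at index `ix` ends a turn: either it is the
    /// last message while the thread is idle, or it is an assistant message
    /// followed by a visible (non-hidden) user message.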
    pub fn is_turn_end(&self, ix: usize) -> bool {
        if self.messages.is_empty() {
            return false;
        }

        if !self.is_generating() && ix == self.messages.len() - 1 {
            return true;
        }

        let Some(message) = self.messages.get(ix) else {
            return false;
        };

        if message.role != Role::Assistant {
            return false;
        }

        self.messages
            .get(ix + 1)
            .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
            .unwrap_or(false)
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, &self.project, cx)
    }

    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display the image.
                None
            }
        }
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Return tools that are both enabled and supported by the model
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            self.profile
                .enabled_tools(cx)
                .into_iter()
                .filter_map(|(name, tool)| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name: name.into(),
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            Vec::default()
        }
    }

    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        message_id
    }

    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }

    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
            ui_only: false,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }

    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }

    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
    pub fn text(&self) -> String {
        let mut text = String::new();

        for message in &self.messages {
            text.push_str(match message.role {
                language_model::Role::User => "User:",
                language_model::Role::Assistant => "Agent:",
                language_model::Role::System => "System:",
            });
            text.push('\n');

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(content) => text.push_str(content),
                    MessageSegment::Thinking { text: content, .. } => {
                        text.push_str(&format!("<think>{}</think>", content))
                    }
                    MessageSegment::RedactedThinking(_) => {}
                }
            }
            text.push('\n');
        }

        text
    }

    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }

    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        self.flush_notifications(model.clone(), intent, cx);

        self.finalize_pending_checkpoint(cx);
        self.stream_completion(
            self.to_completion_request(model.clone(), intent, cx),
            model,
            intent,
            window,
            cx,
        );
    }

    pub fn retry_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        // Clear any existing error state
        self.retry_state = None;

        // Use the last error context if available, otherwise fall back to the configured model
        let (model, intent) = if let Some((model, intent)) = self.last_error_context.take() {
            (model, intent)
        } else if let Some(configured_model) = self.configured_model.as_ref() {
            let model = configured_model.model.clone();
            let intent = if self.has_pending_tool_uses() {
                CompletionIntent::ToolResults
            } else {
                CompletionIntent::UserPrompt
            };
            (model, intent)
        } else if let Some(configured_model) = self.get_or_init_configured_model(cx) {
            let model = configured_model.model.clone();
            let intent = if self.has_pending_tool_uses() {
                CompletionIntent::ToolResults
            } else {
                CompletionIntent::UserPrompt
            };
            (model, intent)
        } else {
            return;
        };

        self.send_to_model(model, intent, window, cx);
    }

    pub fn enable_burn_mode_and_retry(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.completion_mode = CompletionMode::Burn;
        cx.emit(ThreadEvent::ProfileChanged);
        self.retry_last_completion(window, cx);
    }

    pub fn used_tools_since_last_user_message(&self) -> bool {
        for message in self.messages.iter().rev() {
            if self.tool_use.message_has_tool_results(message.id) {
                return true;
            } else if message.role == Role::User {
                return false;
            }
        }

        false
    }

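    /// Builds the completion request for this thread: the generated system
    /// prompt, every non-UI-only message along with its tool uses and results,
    /// the available tools, and a prompt-cache marker on the last cacheable
    /// message.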
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
            thinking_allowed: true,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }

    fn to_summarize_request(
        &self,
        model: &Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        added_user_message: String,
        cx: &App,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(model, cx),
            thinking_allowed: false,
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => request_message
                        .content
                        .push(MessageContent::Text(text.clone())),
                    MessageSegment::Thinking { .. } => {}
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            if request_message.content.is_empty() {
                continue;
            }

            request.messages.push(request_message);
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(added_user_message)],
            cache: false,
        });

        request
    }

    /// Inserts auto-generated notifications (if any) into the thread.
    fn flush_notifications(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) {
        match intent {
            CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
                if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
                    cx.emit(ThreadEvent::ToolFinished {
                        tool_use_id: pending_tool_use.id.clone(),
                        pending_tool_use: Some(pending_tool_use),
                    });
                }
            }
            CompletionIntent::ThreadSummarization
            | CompletionIntent::ThreadContextSummarization
            | CompletionIntent::CreateFile
            | CompletionIntent::EditFile
            | CompletionIntent::InlineAssist
            | CompletionIntent::TerminalInlineAssist
            | CompletionIntent::GenerateGitCommitMessage => {}
        };
    }

    fn attach_tracked_files_state(
        &mut self,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Option<PendingToolUse> {
        // Represent notification as a simulated `project_notifications` tool call
        let tool_name = Arc::from("project_notifications");
        let tool = self.tools.read(cx).tool(&tool_name, cx)?;

        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return None;
        }

        if self
            .action_log
            .update(cx, |log, cx| log.unnotified_user_edits(cx).is_none())
        {
            return None;
        }

        let input = serde_json::json!({});
        let request = Arc::new(LanguageModelRequest::default()); // unused
        let window = None;
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model.clone(),
            window,
            cx,
        );

        let tool_use_id =
            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));

        let tool_use = LanguageModelToolUse {
            id: tool_use_id.clone(),
            name: tool_name.clone(),
            raw_input: "{}".to_string(),
            input: serde_json::json!({}),
            is_input_complete: true,
        };

        let tool_output = cx.background_executor().block(tool_result.output);

        // Attach a project_notification tool call to the latest existing
        // Assistant message. We cannot create a new Assistant message
        // because thinking models require a `thinking` block that we
        // cannot mock. We cannot send a notification as a normal
        // (non-tool-use) User message because this distracts Agent
        // too much.
        let tool_message_id = self
            .messages
            .iter()
            .enumerate()
            .rfind(|(_, message)| message.role == Role::Assistant)
            .map(|(_, message)| message.id)?;

        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: self.last_prompt_id.clone(),
        };

        self.tool_use
            .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);

        self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            tool_output,
            self.configured_model.as_ref(),
            self.completion_mode,
        )
    }

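    /// Spawns a task that streams completion events from the model, applying
    /// text, thinking, and tool-use chunks to the thread as they arrive, and
    /// scheduling a retry when the request fails with a retryable error.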
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let completion_mode = request
            .mode
            .unwrap_or(cloud_llm_client::CompletionMode::Normal);

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        match event? {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
1789 }
1790 }
1791 LanguageModelCompletionEvent::RedactedThinking { data } => {
1792 thread.received_chunk();
1793
1794 if let Some(last_message) = thread.messages.last_mut() {
1795 if last_message.role == Role::Assistant
1796 && !thread.tool_use.has_tool_results(last_message.id)
1797 {
1798 last_message.push_redacted_thinking(data);
1799 } else {
1800 request_assistant_message_id =
1801 Some(thread.insert_assistant_message(
1802 vec![MessageSegment::RedactedThinking(data)],
1803 cx,
1804 ));
1805 };
1806 }
1807 }
1808 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1809 let last_assistant_message_id = request_assistant_message_id
1810 .unwrap_or_else(|| {
1811 let new_assistant_message_id =
1812 thread.insert_assistant_message(vec![], cx);
1813 request_assistant_message_id =
1814 Some(new_assistant_message_id);
1815 new_assistant_message_id
1816 });
1817
1818 let tool_use_id = tool_use.id.clone();
1819 let streamed_input = if tool_use.is_input_complete {
1820 None
1821 } else {
1822 Some(tool_use.input.clone())
1823 };
1824
1825 let ui_text = thread.tool_use.request_tool_use(
1826 last_assistant_message_id,
1827 tool_use,
1828 tool_use_metadata.clone(),
1829 cx,
1830 );
1831
1832 if let Some(input) = streamed_input {
1833 cx.emit(ThreadEvent::StreamedToolUse {
1834 tool_use_id,
1835 ui_text,
1836 input,
1837 });
1838 }
1839 }
1840 LanguageModelCompletionEvent::ToolUseJsonParseError {
1841 id,
1842 tool_name,
1843 raw_input: invalid_input_json,
1844 json_parse_error,
1845 } => {
1846 thread.receive_invalid_tool_json(
1847 id,
1848 tool_name,
1849 invalid_input_json,
1850 json_parse_error,
1851 window,
1852 cx,
1853 );
1854 }
1855 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1856 if let Some(completion) = thread
1857 .pending_completions
1858 .iter_mut()
1859 .find(|completion| completion.id == pending_completion_id)
1860 {
1861 match status_update {
1862 CompletionRequestStatus::Queued { position } => {
1863 completion.queue_state =
1864 QueueState::Queued { position };
1865 }
1866 CompletionRequestStatus::Started => {
1867 completion.queue_state = QueueState::Started;
1868 }
1869 CompletionRequestStatus::Failed {
1870 code,
1871 message,
1872 request_id: _,
1873 retry_after,
1874 } => {
1875 return Err(
1876 LanguageModelCompletionError::from_cloud_failure(
1877 model.upstream_provider_name(),
1878 code,
1879 message,
1880 retry_after.map(Duration::from_secs_f64),
1881 ),
1882 );
1883 }
1884 CompletionRequestStatus::UsageUpdated { amount, limit } => {
1885 thread.update_model_request_usage(
1886 amount as u32,
1887 limit,
1888 cx,
1889 );
1890 }
1891 CompletionRequestStatus::ToolUseLimitReached => {
1892 thread.tool_use_limit_reached = true;
1893 cx.emit(ThreadEvent::ToolUseLimitReached);
1894 }
1895 }
1896 }
1897 }
1898 }
1899
1900 thread.touch_updated_at();
1901 cx.emit(ThreadEvent::StreamedCompletion);
1902 cx.notify();
1903
1904 Ok(())
1905 })??;
1906
1907 smol::future::yield_now().await;
1908 }
1909
1910 thread.update(cx, |thread, cx| {
1911 thread.last_received_chunk_at = None;
1912 thread
1913 .pending_completions
1914 .retain(|completion| completion.id != pending_completion_id);
1915
1916 // If there is a response without tool use, summarize the message. Otherwise,
1917 // allow two tool uses before summarizing.
1918 if matches!(thread.summary, ThreadSummary::Pending)
1919 && thread.messages.len() >= 2
1920 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1921 {
1922 thread.summarize(cx);
1923 }
1924 })?;
1925
1926 anyhow::Ok(stop_reason)
1927 };
1928
1929 let result = stream_completion.await;
1930 let mut retry_scheduled = false;
1931
1932 thread
1933 .update(cx, |thread, cx| {
1934 thread.finalize_pending_checkpoint(cx);
1935 match result.as_ref() {
1936 Ok(stop_reason) => {
1937 match stop_reason {
1938 StopReason::ToolUse => {
1939 let tool_uses =
1940 thread.use_pending_tools(window, model.clone(), cx);
1941 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1942 }
1943 StopReason::EndTurn | StopReason::MaxTokens => {
1944 thread.project.update(cx, |project, cx| {
1945 project.set_agent_location(None, cx);
1946 });
1947 }
1948 StopReason::Refusal => {
1949 thread.project.update(cx, |project, cx| {
1950 project.set_agent_location(None, cx);
1951 });
1952
1953 // Remove the turn that was refused.
1954 //
1955 // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1956 {
1957 let mut messages_to_remove = Vec::new();
1958
1959 for (ix, message) in
1960 thread.messages.iter().enumerate().rev()
1961 {
1962 messages_to_remove.push(message.id);
1963
1964 if message.role == Role::User {
1965 if ix == 0 {
1966 break;
1967 }
1968
1969 if let Some(prev_message) =
1970 thread.messages.get(ix - 1)
1971 && prev_message.role == Role::Assistant {
1972 break;
1973 }
1974 }
1975 }
1976
1977 for message_id in messages_to_remove {
1978 thread.delete_message(message_id, cx);
1979 }
1980 }
1981
1982 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1983 header: "Language model refusal".into(),
1984 message:
1985 "Model refused to generate content for safety reasons."
1986 .into(),
1987 }));
1988 }
1989 }
1990
1991 // We successfully completed, so cancel any remaining retries.
1992 thread.retry_state = None;
1993 }
1994 Err(error) => {
1995 thread.project.update(cx, |project, cx| {
1996 project.set_agent_location(None, cx);
1997 });
1998
1999 if error.is::<PaymentRequiredError>() {
2000 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
2001 } else if let Some(error) =
2002 error.downcast_ref::<ModelRequestLimitReachedError>()
2003 {
2004 cx.emit(ThreadEvent::ShowError(
2005 ThreadError::ModelRequestLimitReached { plan: error.plan },
2006 ));
2007 } else if let Some(completion_error) =
2008 error.downcast_ref::<LanguageModelCompletionError>()
2009 {
2010 match &completion_error {
2011 LanguageModelCompletionError::PromptTooLarge {
2012 tokens, ..
2013 } => {
2014 let tokens = tokens.unwrap_or_else(|| {
2015 // We didn't get an exact token count from the API, so fall back on our estimate.
2016 thread
2017 .total_token_usage()
2018 .map(|usage| usage.total)
2019 .unwrap_or(0)
2020 // We know the context window was exceeded in practice, so if our estimate was
2021 // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
2022 .max(
2023 model
2024 .max_token_count_for_mode(completion_mode)
2025 .saturating_add(1),
2026 )
2027 });
2028 thread.exceeded_window_error = Some(ExceededWindowError {
2029 model_id: model.id(),
2030 token_count: tokens,
2031 });
2032 cx.notify();
2033 }
2034 _ => {
2035 if let Some(retry_strategy) =
2036 Thread::get_retry_strategy(completion_error)
2037 {
2038 log::info!(
2039 "Retrying with {:?} for language model completion error {:?}",
2040 retry_strategy,
2041 completion_error
2042 );
2043
2044 retry_scheduled = thread
2045 .handle_retryable_error_with_delay(
2046 completion_error,
2047 Some(retry_strategy),
2048 model.clone(),
2049 intent,
2050 window,
2051 cx,
2052 );
2053 }
2054 }
2055 }
2056 }
2057
2058 if !retry_scheduled {
2059 thread.cancel_last_completion(window, cx);
2060 }
2061 }
2062 }
2063
2064 if !retry_scheduled {
2065 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
2066 }
2067
2068 if let Some((request_callback, (request, response_events))) = thread
2069 .request_callback
2070 .as_mut()
2071 .zip(request_callback_parameters.as_ref())
2072 {
2073 request_callback(request, response_events);
2074 }
2075
2076 if let Ok(initial_usage) = initial_token_usage {
2077 let usage = thread.cumulative_token_usage - initial_usage;
2078
2079 telemetry::event!(
2080 "Assistant Thread Completion",
2081 thread_id = thread.id().to_string(),
2082 prompt_id = prompt_id,
2083 model = model.telemetry_id(),
2084 model_provider = model.provider_id().to_string(),
2085 input_tokens = usage.input_tokens,
2086 output_tokens = usage.output_tokens,
2087 cache_creation_input_tokens = usage.cache_creation_input_tokens,
2088 cache_read_input_tokens = usage.cache_read_input_tokens,
2089 );
2090 }
2091 })
2092 .ok();
2093 });
2094
2095 self.pending_completions.push(PendingCompletion {
2096 id: pending_completion_id,
2097 queue_state: QueueState::Sending,
2098 _task: task,
2099 });
2100 }
2101
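    /// Generates a short, single-line summary of the thread using the configured
    /// thread summary model, emitting `SummaryGenerated` once the stream finishes.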
2102 pub fn summarize(&mut self, cx: &mut Context<Self>) {
2103 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            log::warn!("No thread summary model configured");
2105 return;
2106 };
2107
2108 if !model.provider.is_authenticated(cx) {
2109 return;
2110 }
2111
2112 let request = self.to_summarize_request(
2113 &model.model,
2114 CompletionIntent::ThreadSummarization,
2115 SUMMARIZE_THREAD_PROMPT.into(),
2116 cx,
2117 );
2118
2119 self.summary = ThreadSummary::Generating;
2120
2121 self.pending_summary = cx.spawn(async move |this, cx| {
2122 let result = async {
2123 let mut messages = model.model.stream_completion(request, cx).await?;
2124
2125 let mut new_summary = String::new();
2126 while let Some(event) = messages.next().await {
2127 let Ok(event) = event else {
2128 continue;
2129 };
2130 let text = match event {
2131 LanguageModelCompletionEvent::Text(text) => text,
2132 LanguageModelCompletionEvent::StatusUpdate(
2133 CompletionRequestStatus::UsageUpdated { amount, limit },
2134 ) => {
2135 this.update(cx, |thread, cx| {
2136 thread.update_model_request_usage(amount as u32, limit, cx);
2137 })?;
2138 continue;
2139 }
2140 _ => continue,
2141 };
2142
2143 let mut lines = text.lines();
2144 new_summary.extend(lines.next());
2145
2146 // Stop if the LLM generated multiple lines.
2147 if lines.next().is_some() {
2148 break;
2149 }
2150 }
2151
2152 anyhow::Ok(new_summary)
2153 }
2154 .await;
2155
2156 this.update(cx, |this, cx| {
2157 match result {
2158 Ok(new_summary) => {
2159 if new_summary.is_empty() {
2160 this.summary = ThreadSummary::Error;
2161 } else {
2162 this.summary = ThreadSummary::Ready(new_summary.into());
2163 }
2164 }
2165 Err(err) => {
2166 this.summary = ThreadSummary::Error;
2167 log::error!("Failed to generate thread summary: {}", err);
2168 }
2169 }
2170 cx.emit(ThreadEvent::SummaryGenerated);
2171 })
2172 .log_err()?;
2173
2174 Some(())
2175 });
2176 }
2177
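    /// Maps a completion error to a retry strategy, or `None` when retrying
    /// cannot help (e.g. authentication failures or an oversized prompt).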
2178 fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
2179 use LanguageModelCompletionError::*;
2180
        // General strategy here:
        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times,
        //   waiting either with exponential backoff or for the delay the server requested.
        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
2185 match error {
2186 HttpResponseError {
2187 status_code: StatusCode::TOO_MANY_REQUESTS,
2188 ..
2189 } => Some(RetryStrategy::ExponentialBackoff {
2190 initial_delay: BASE_RETRY_DELAY,
2191 max_attempts: MAX_RETRY_ATTEMPTS,
2192 }),
2193 ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
2194 Some(RetryStrategy::Fixed {
2195 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2196 max_attempts: MAX_RETRY_ATTEMPTS,
2197 })
2198 }
2199 UpstreamProviderError {
2200 status,
2201 retry_after,
2202 ..
2203 } => match *status {
2204 StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
2205 Some(RetryStrategy::Fixed {
2206 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2207 max_attempts: MAX_RETRY_ATTEMPTS,
2208 })
2209 }
2210 StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
2211 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2212 // Internal Server Error could be anything, retry up to 3 times.
2213 max_attempts: 3,
2214 }),
2215 status => {
2216 // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
2217 // but we frequently get them in practice. See https://http.dev/529
2218 if status.as_u16() == 529 {
2219 Some(RetryStrategy::Fixed {
2220 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2221 max_attempts: MAX_RETRY_ATTEMPTS,
2222 })
2223 } else {
2224 Some(RetryStrategy::Fixed {
2225 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2226 max_attempts: 2,
2227 })
2228 }
2229 }
2230 },
2231 ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
2232 delay: BASE_RETRY_DELAY,
2233 max_attempts: 3,
2234 }),
2235 ApiReadResponseError { .. }
2236 | HttpSend { .. }
2237 | DeserializeResponse { .. }
2238 | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
2239 delay: BASE_RETRY_DELAY,
2240 max_attempts: 3,
2241 }),
2242 // Retrying these errors definitely shouldn't help.
2243 HttpResponseError {
2244 status_code:
2245 StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
2246 ..
2247 }
2248 | AuthenticationError { .. }
2249 | PermissionError { .. }
2250 | NoApiKey { .. }
2251 | ApiEndpointNotFound { .. }
2252 | PromptTooLarge { .. } => None,
            // These errors might be transient, so retry them once.
2254 SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
2255 delay: BASE_RETRY_DELAY,
2256 max_attempts: 1,
2257 }),
            // Retry all other 4xx and 5xx errors up to 3 times.
2259 HttpResponseError { status_code, .. }
2260 if status_code.is_client_error() || status_code.is_server_error() =>
2261 {
2262 Some(RetryStrategy::Fixed {
2263 delay: BASE_RETRY_DELAY,
2264 max_attempts: 3,
2265 })
2266 }
2267 Other(err)
2268 if err.is::<PaymentRequiredError>()
2269 || err.is::<ModelRequestLimitReachedError>() =>
2270 {
2271 // Retrying won't help for Payment Required or Model Request Limit errors (where
2272 // the user must upgrade to usage-based billing to get more requests, or else wait
2273 // for a significant amount of time for the request limit to reset).
2274 None
2275 }
2276 // Conservatively assume that any other errors are non-retryable
2277 HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
2278 delay: BASE_RETRY_DELAY,
2279 max_attempts: 2,
2280 }),
2281 }
2282 }
2283
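    /// Handles an error that may be worth retrying. Auto-retries only happen in
    /// Burn Mode; returns whether a retry was actually scheduled.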
2284 fn handle_retryable_error_with_delay(
2285 &mut self,
2286 error: &LanguageModelCompletionError,
2287 strategy: Option<RetryStrategy>,
2288 model: Arc<dyn LanguageModel>,
2289 intent: CompletionIntent,
2290 window: Option<AnyWindowHandle>,
2291 cx: &mut Context<Self>,
2292 ) -> bool {
2293 // Store context for the Retry button
2294 self.last_error_context = Some((model.clone(), intent));
2295
2296 // Only auto-retry if Burn Mode is enabled
2297 if self.completion_mode != CompletionMode::Burn {
2298 // Show error with retry options
2299 cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
2300 message: format!(
2301 "{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.",
2302 error
2303 )
2304 .into(),
2305 can_enable_burn_mode: true,
2306 }));
2307 return false;
2308 }
2309
2310 let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
2311 return false;
2312 };
2313
2314 let max_attempts = match &strategy {
2315 RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
2316 RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
2317 };
2318
2319 let retry_state = self.retry_state.get_or_insert(RetryState {
2320 attempt: 0,
2321 max_attempts,
2322 intent,
2323 });
2324
2325 retry_state.attempt += 1;
2326 let attempt = retry_state.attempt;
2327 let max_attempts = retry_state.max_attempts;
2328 let intent = retry_state.intent;
2329
2330 if attempt <= max_attempts {
2331 let delay = match &strategy {
2332 RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
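                    // The delay doubles on each attempt; with the 5-second base
                    // delay that yields 5s, 10s, 20s, and 40s across four attempts.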
2333 let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
2334 Duration::from_secs(delay_secs)
2335 }
2336 RetryStrategy::Fixed { delay, .. } => *delay,
2337 };
2338
2339 // Add a transient message to inform the user
2340 let delay_secs = delay.as_secs();
2341 let retry_message = if max_attempts == 1 {
2342 format!("{error}. Retrying in {delay_secs} seconds...")
2343 } else {
2344 format!(
2345 "{error}. Retrying (attempt {attempt} of {max_attempts}) \
2346 in {delay_secs} seconds..."
2347 )
2348 };
2349 log::warn!(
2350 "Retrying completion request (attempt {attempt} of {max_attempts}) \
2351 in {delay_secs} seconds: {error:?}",
2352 );
2353
2354 // Add a UI-only message instead of a regular message
2355 let id = self.next_message_id.post_inc();
2356 self.messages.push(Message {
2357 id,
2358 role: Role::System,
2359 segments: vec![MessageSegment::Text(retry_message)],
2360 loaded_context: LoadedContext::default(),
2361 creases: Vec::new(),
2362 is_hidden: false,
2363 ui_only: true,
2364 });
2365 cx.emit(ThreadEvent::MessageAdded(id));
2366
2367 // Schedule the retry
2368 let thread_handle = cx.entity().downgrade();
2369
2370 cx.spawn(async move |_thread, cx| {
2371 cx.background_executor().timer(delay).await;
2372
2373 thread_handle
2374 .update(cx, |thread, cx| {
2375 // Retry the completion
2376 thread.send_to_model(model, intent, window, cx);
2377 })
2378 .log_err();
2379 })
2380 .detach();
2381
2382 true
2383 } else {
2384 // Max retries exceeded
2385 self.retry_state = None;
2386
2387 // Stop generating since we're giving up on retrying.
2388 self.pending_completions.clear();
2389
2390 // Show error alongside a Retry button, but no
2391 // Enable Burn Mode button (since it's already enabled)
2392 cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
2393 message: format!("Failed after retrying: {}", error).into(),
2394 can_enable_burn_mode: false,
2395 }));
2396
2397 false
2398 }
2399 }
2400
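    /// Starts generating a detailed summary unless one is already generating or
    /// generated for the current last message.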
2401 pub fn start_generating_detailed_summary_if_needed(
2402 &mut self,
2403 thread_store: WeakEntity<ThreadStore>,
2404 cx: &mut Context<Self>,
2405 ) {
2406 let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
2407 return;
2408 };
2409
2410 match &*self.detailed_summary_rx.borrow() {
2411 DetailedSummaryState::Generating { message_id, .. }
2412 | DetailedSummaryState::Generated { message_id, .. }
2413 if *message_id == last_message_id =>
2414 {
2415 // Already up-to-date
2416 return;
2417 }
2418 _ => {}
2419 }
2420
2421 let Some(ConfiguredModel { model, provider }) =
2422 LanguageModelRegistry::read_global(cx).thread_summary_model()
2423 else {
2424 return;
2425 };
2426
2427 if !provider.is_authenticated(cx) {
2428 return;
2429 }
2430
2431 let request = self.to_summarize_request(
2432 &model,
2433 CompletionIntent::ThreadContextSummarization,
2434 SUMMARIZE_THREAD_DETAILED_PROMPT.into(),
2435 cx,
2436 );
2437
2438 *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
2439 message_id: last_message_id,
2440 };
2441
2442 // Replace the detailed summarization task if there is one, cancelling it. It would probably
2443 // be better to allow the old task to complete, but this would require logic for choosing
2444 // which result to prefer (the old task could complete after the new one, resulting in a
2445 // stale summary).
2446 self.detailed_summary_task = cx.spawn(async move |thread, cx| {
2447 let stream = model.stream_completion_text(request, cx);
2448 let Some(mut messages) = stream.await.log_err() else {
2449 thread
2450 .update(cx, |thread, _cx| {
2451 *thread.detailed_summary_tx.borrow_mut() =
2452 DetailedSummaryState::NotGenerated;
2453 })
2454 .ok()?;
2455 return None;
2456 };
2457
2458 let mut new_detailed_summary = String::new();
2459
2460 while let Some(chunk) = messages.stream.next().await {
2461 if let Some(chunk) = chunk.log_err() {
2462 new_detailed_summary.push_str(&chunk);
2463 }
2464 }
2465
2466 thread
2467 .update(cx, |thread, _cx| {
2468 *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
2469 text: new_detailed_summary.into(),
2470 message_id: last_message_id,
2471 };
2472 })
2473 .ok()?;
2474
2475 // Save thread so its summary can be reused later
2476 if let Some(thread) = thread.upgrade()
2477 && let Ok(Ok(save_task)) = cx.update(|cx| {
2478 thread_store
2479 .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
2480 })
2481 {
2482 save_task.await.log_err();
2483 }
2484
2485 Some(())
2486 });
2487 }
2488
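    /// Waits until any in-flight detailed summary settles, falling back to the
    /// thread's plain text if no detailed summary was generated.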
2489 pub async fn wait_for_detailed_summary_or_text(
2490 this: &Entity<Self>,
2491 cx: &mut AsyncApp,
2492 ) -> Option<SharedString> {
2493 let mut detailed_summary_rx = this
2494 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2495 .ok()?;
2496 loop {
2497 match detailed_summary_rx.recv().await? {
2498 DetailedSummaryState::Generating { .. } => {}
2499 DetailedSummaryState::NotGenerated => {
2500 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2501 }
2502 DetailedSummaryState::Generated { text, .. } => return Some(text),
2503 }
2504 }
2505 }
2506
2507 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2508 self.detailed_summary_rx
2509 .borrow()
2510 .text()
2511 .unwrap_or_else(|| self.text().into())
2512 }
2513
2514 pub fn is_generating_detailed_summary(&self) -> bool {
2515 matches!(
2516 &*self.detailed_summary_rx.borrow(),
2517 DetailedSummaryState::Generating { .. }
2518 )
2519 }
2520
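    /// Starts running every idle pending tool use and returns the set that was started.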
2521 pub fn use_pending_tools(
2522 &mut self,
2523 window: Option<AnyWindowHandle>,
2524 model: Arc<dyn LanguageModel>,
2525 cx: &mut Context<Self>,
2526 ) -> Vec<PendingToolUse> {
2527 let request =
2528 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2529 let pending_tool_uses = self
2530 .tool_use
2531 .pending_tool_uses()
2532 .into_iter()
2533 .filter(|tool_use| tool_use.status.is_idle())
2534 .cloned()
2535 .collect::<Vec<_>>();
2536
2537 for tool_use in pending_tool_uses.iter() {
2538 self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2539 }
2540
2541 pending_tool_uses
2542 }
2543
2544 fn use_pending_tool(
2545 &mut self,
2546 tool_use: PendingToolUse,
2547 request: Arc<LanguageModelRequest>,
2548 model: Arc<dyn LanguageModel>,
2549 window: Option<AnyWindowHandle>,
2550 cx: &mut Context<Self>,
2551 ) {
2552 let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2553 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2554 };
2555
2556 if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2557 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2558 }
2559
2560 if tool.needs_confirmation(&tool_use.input, &self.project, cx)
2561 && !AgentSettings::get_global(cx).always_allow_tool_actions
2562 {
2563 self.tool_use.confirm_tool_use(
2564 tool_use.id,
2565 tool_use.ui_text,
2566 tool_use.input,
2567 request,
2568 tool,
2569 );
2570 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2571 } else {
2572 self.run_tool(
2573 tool_use.id,
2574 tool_use.ui_text,
2575 tool_use.input,
2576 request,
2577 tool,
2578 model,
2579 window,
2580 cx,
2581 );
2582 }
2583 }
2584
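    /// Reports an error tool result for a tool that doesn't exist or isn't
    /// enabled, listing the tools that are actually available to the model.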
2585 pub fn handle_hallucinated_tool_use(
2586 &mut self,
2587 tool_use_id: LanguageModelToolUseId,
2588 hallucinated_tool_name: Arc<str>,
2589 window: Option<AnyWindowHandle>,
2590 cx: &mut Context<Thread>,
2591 ) {
2592 let available_tools = self.profile.enabled_tools(cx);
2593
2594 let tool_list = available_tools
2595 .iter()
2596 .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2597 .collect::<Vec<_>>()
2598 .join("\n");
2599
2600 let error_message = format!(
2601 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2602 hallucinated_tool_name, tool_list
2603 );
2604
2605 let pending_tool_use = self.tool_use.insert_tool_output(
2606 tool_use_id.clone(),
2607 hallucinated_tool_name,
2608 Err(anyhow!("Missing tool call: {error_message}")),
2609 self.configured_model.as_ref(),
2610 self.completion_mode,
2611 );
2612
2613 cx.emit(ThreadEvent::MissingToolUse {
2614 tool_use_id: tool_use_id.clone(),
2615 ui_text: error_message.into(),
2616 });
2617
2618 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2619 }
2620
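    /// Records an error tool result when the model produced tool input that
    /// failed to parse as JSON.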
2621 pub fn receive_invalid_tool_json(
2622 &mut self,
2623 tool_use_id: LanguageModelToolUseId,
2624 tool_name: Arc<str>,
2625 invalid_json: Arc<str>,
2626 error: String,
2627 window: Option<AnyWindowHandle>,
2628 cx: &mut Context<Thread>,
2629 ) {
2630 log::error!("The model returned invalid input JSON: {invalid_json}");
2631
2632 let pending_tool_use = self.tool_use.insert_tool_output(
2633 tool_use_id.clone(),
2634 tool_name,
2635 Err(anyhow!("Error parsing input JSON: {error}")),
2636 self.configured_model.as_ref(),
2637 self.completion_mode,
2638 );
2639 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2640 pending_tool_use.ui_text.clone()
2641 } else {
2642 log::error!(
2643 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2644 );
2645 format!("Unknown tool {}", tool_use_id).into()
2646 };
2647
2648 cx.emit(ThreadEvent::InvalidToolInput {
2649 tool_use_id: tool_use_id.clone(),
2650 ui_text,
2651 invalid_input_json: invalid_json,
2652 });
2653
2654 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2655 }
2656
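    /// Runs a confirmed tool use by spawning its task and marking it as running.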
2657 pub fn run_tool(
2658 &mut self,
2659 tool_use_id: LanguageModelToolUseId,
2660 ui_text: impl Into<SharedString>,
2661 input: serde_json::Value,
2662 request: Arc<LanguageModelRequest>,
2663 tool: Arc<dyn Tool>,
2664 model: Arc<dyn LanguageModel>,
2665 window: Option<AnyWindowHandle>,
2666 cx: &mut Context<Thread>,
2667 ) {
2668 let task =
2669 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2670 self.tool_use
2671 .run_pending_tool(tool_use_id, ui_text.into(), task);
2672 }
2673
2674 fn spawn_tool_use(
2675 &mut self,
2676 tool_use_id: LanguageModelToolUseId,
2677 request: Arc<LanguageModelRequest>,
2678 input: serde_json::Value,
2679 tool: Arc<dyn Tool>,
2680 model: Arc<dyn LanguageModel>,
2681 window: Option<AnyWindowHandle>,
2682 cx: &mut Context<Thread>,
2683 ) -> Task<()> {
2684 let tool_name: Arc<str> = tool.name().into();
2685
2686 let tool_result = tool.run(
2687 input,
2688 request,
2689 self.project.clone(),
2690 self.action_log.clone(),
2691 model,
2692 window,
2693 cx,
2694 );
2695
2696 // Store the card separately if it exists
2697 if let Some(card) = tool_result.card.clone() {
2698 self.tool_use
2699 .insert_tool_result_card(tool_use_id.clone(), card);
2700 }
2701
2702 cx.spawn({
2703 async move |thread: WeakEntity<Thread>, cx| {
2704 let output = tool_result.output.await;
2705
2706 thread
2707 .update(cx, |thread, cx| {
2708 let pending_tool_use = thread.tool_use.insert_tool_output(
2709 tool_use_id.clone(),
2710 tool_name,
2711 output,
2712 thread.configured_model.as_ref(),
2713 thread.completion_mode,
2714 );
2715 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2716 })
2717 .ok();
2718 }
2719 })
2720 }
2721
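    /// Marks a tool use as finished and, once all pending tools have completed
    /// (and the run wasn't canceled), sends the results back to the model.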
2722 fn tool_finished(
2723 &mut self,
2724 tool_use_id: LanguageModelToolUseId,
2725 pending_tool_use: Option<PendingToolUse>,
2726 canceled: bool,
2727 window: Option<AnyWindowHandle>,
2728 cx: &mut Context<Self>,
2729 ) {
2730 if self.all_tools_finished()
2731 && let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref()
2732 && !canceled
2733 {
2734 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2735 }
2736
2737 cx.emit(ThreadEvent::ToolFinished {
2738 tool_use_id,
2739 pending_tool_use,
2740 });
2741 }
2742
    /// Cancels the last pending completion, if any are pending.
2744 ///
2745 /// Returns whether a completion was canceled.
2746 pub fn cancel_last_completion(
2747 &mut self,
2748 window: Option<AnyWindowHandle>,
2749 cx: &mut Context<Self>,
2750 ) -> bool {
2751 let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2752
2753 self.retry_state = None;
2754
2755 for pending_tool_use in self.tool_use.cancel_pending() {
2756 canceled = true;
2757 self.tool_finished(
2758 pending_tool_use.id.clone(),
2759 Some(pending_tool_use),
2760 true,
2761 window,
2762 cx,
2763 );
2764 }
2765
2766 if canceled {
2767 cx.emit(ThreadEvent::CompletionCanceled);
2768
2769 // When canceled, we always want to insert the checkpoint.
2770 // (We skip over finalize_pending_checkpoint, because it
2771 // would conclude we didn't have anything to insert here.)
2772 if let Some(checkpoint) = self.pending_checkpoint.take() {
2773 self.insert_checkpoint(checkpoint, cx);
2774 }
2775 } else {
2776 self.finalize_pending_checkpoint(cx);
2777 }
2778
2779 canceled
2780 }
2781
2782 /// Signals that any in-progress editing should be canceled.
2783 ///
2784 /// This method is used to notify listeners (like ActiveThread) that
2785 /// they should cancel any editing operations.
2786 pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
2787 cx.emit(ThreadEvent::CancelEditing);
2788 }
2789
2790 pub fn feedback(&self) -> Option<ThreadFeedback> {
2791 self.feedback
2792 }
2793
2794 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
2795 self.message_feedback.get(&message_id).copied()
2796 }
2797
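    /// Records feedback for a single message and reports it to telemetry along
    /// with a project snapshot and the serialized thread.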
2798 pub fn report_message_feedback(
2799 &mut self,
2800 message_id: MessageId,
2801 feedback: ThreadFeedback,
2802 cx: &mut Context<Self>,
2803 ) -> Task<Result<()>> {
2804 if self.message_feedback.get(&message_id) == Some(&feedback) {
2805 return Task::ready(Ok(()));
2806 }
2807
2808 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2809 let serialized_thread = self.serialize(cx);
2810 let thread_id = self.id().clone();
2811 let client = self.project.read(cx).client();
2812
2813 let enabled_tool_names: Vec<String> = self
2814 .profile
2815 .enabled_tools(cx)
2816 .iter()
2817 .map(|(name, _)| name.clone().into())
2818 .collect();
2819
2820 self.message_feedback.insert(message_id, feedback);
2821
2822 cx.notify();
2823
2824 let message_content = self
2825 .message(message_id)
2826 .map(|msg| msg.to_string())
2827 .unwrap_or_default();
2828
2829 cx.background_spawn(async move {
2830 let final_project_snapshot = final_project_snapshot.await;
2831 let serialized_thread = serialized_thread.await?;
2832 let thread_data =
2833 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2834
2835 let rating = match feedback {
2836 ThreadFeedback::Positive => "positive",
2837 ThreadFeedback::Negative => "negative",
2838 };
2839 telemetry::event!(
2840 "Assistant Thread Rated",
2841 rating,
2842 thread_id,
2843 enabled_tool_names,
2844 message_id = message_id.0,
2845 message_content,
2846 thread_data,
2847 final_project_snapshot
2848 );
2849 client.telemetry().flush_events().await;
2850
2851 Ok(())
2852 })
2853 }
2854
2855 pub fn report_feedback(
2856 &mut self,
2857 feedback: ThreadFeedback,
2858 cx: &mut Context<Self>,
2859 ) -> Task<Result<()>> {
2860 let last_assistant_message_id = self
2861 .messages
2862 .iter()
2863 .rev()
2864 .find(|msg| msg.role == Role::Assistant)
2865 .map(|msg| msg.id);
2866
2867 if let Some(message_id) = last_assistant_message_id {
2868 self.report_message_feedback(message_id, feedback, cx)
2869 } else {
2870 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2871 let serialized_thread = self.serialize(cx);
2872 let thread_id = self.id().clone();
2873 let client = self.project.read(cx).client();
2874 self.feedback = Some(feedback);
2875 cx.notify();
2876
2877 cx.background_spawn(async move {
2878 let final_project_snapshot = final_project_snapshot.await;
2879 let serialized_thread = serialized_thread.await?;
2880 let thread_data = serde_json::to_value(serialized_thread)
2881 .unwrap_or_else(|_| serde_json::Value::Null);
2882
2883 let rating = match feedback {
2884 ThreadFeedback::Positive => "positive",
2885 ThreadFeedback::Negative => "negative",
2886 };
2887 telemetry::event!(
2888 "Assistant Thread Rated",
2889 rating,
2890 thread_id,
2891 thread_data,
2892 final_project_snapshot
2893 );
2894 client.telemetry().flush_events().await;
2895
2896 Ok(())
2897 })
2898 }
2899 }
2900
2901 /// Create a snapshot of the current project state including git information and unsaved buffers.
2902 fn project_snapshot(
2903 project: Entity<Project>,
2904 cx: &mut Context<Self>,
2905 ) -> Task<Arc<ProjectSnapshot>> {
2906 let git_store = project.read(cx).git_store().clone();
2907 let worktree_snapshots: Vec<_> = project
2908 .read(cx)
2909 .visible_worktrees(cx)
2910 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2911 .collect();
2912
2913 cx.spawn(async move |_, cx| {
2914 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2915
2916 let mut unsaved_buffers = Vec::new();
2917 cx.update(|app_cx| {
2918 let buffer_store = project.read(app_cx).buffer_store();
2919 for buffer_handle in buffer_store.read(app_cx).buffers() {
2920 let buffer = buffer_handle.read(app_cx);
2921 if buffer.is_dirty()
2922 && let Some(file) = buffer.file()
2923 {
2924 let path = file.path().to_string_lossy().to_string();
2925 unsaved_buffers.push(path);
2926 }
2927 }
2928 })
2929 .ok();
2930
2931 Arc::new(ProjectSnapshot {
2932 worktree_snapshots,
2933 unsaved_buffer_paths: unsaved_buffers,
2934 timestamp: Utc::now(),
2935 })
2936 })
2937 }
2938
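    /// Captures a worktree's path plus git metadata (remote URL, HEAD SHA,
    /// current branch, and diff) when the worktree maps to a local repository.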
2939 fn worktree_snapshot(
2940 worktree: Entity<project::Worktree>,
2941 git_store: Entity<GitStore>,
2942 cx: &App,
2943 ) -> Task<WorktreeSnapshot> {
2944 cx.spawn(async move |cx| {
2945 // Get worktree path and snapshot
2946 let worktree_info = cx.update(|app_cx| {
2947 let worktree = worktree.read(app_cx);
2948 let path = worktree.abs_path().to_string_lossy().to_string();
2949 let snapshot = worktree.snapshot();
2950 (path, snapshot)
2951 });
2952
2953 let Ok((worktree_path, _snapshot)) = worktree_info else {
2954 return WorktreeSnapshot {
2955 worktree_path: String::new(),
2956 git_state: None,
2957 };
2958 };
2959
2960 let git_state = git_store
2961 .update(cx, |git_store, cx| {
2962 git_store
2963 .repositories()
2964 .values()
2965 .find(|repo| {
2966 repo.read(cx)
2967 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2968 .is_some()
2969 })
2970 .cloned()
2971 })
2972 .ok()
2973 .flatten()
2974 .map(|repo| {
2975 repo.update(cx, |repo, _| {
2976 let current_branch =
2977 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2978 repo.send_job(None, |state, _| async move {
2979 let RepositoryState::Local { backend, .. } = state else {
2980 return GitState {
2981 remote_url: None,
2982 head_sha: None,
2983 current_branch,
2984 diff: None,
2985 };
2986 };
2987
2988 let remote_url = backend.remote_url("origin");
2989 let head_sha = backend.head_sha().await;
2990 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2991
2992 GitState {
2993 remote_url,
2994 head_sha,
2995 current_branch,
2996 diff,
2997 }
2998 })
2999 })
3000 });
3001
            let git_state = match git_state.and_then(|git_state| git_state.ok()) {
                Some(git_state) => git_state.await.ok(),
                None => None,
            };
3009
3010 WorktreeSnapshot {
3011 worktree_path,
3012 git_state,
3013 }
3014 })
3015 }
3016
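    /// Renders the whole thread, including tool uses and tool results, as Markdown.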
3017 pub fn to_markdown(&self, cx: &App) -> Result<String> {
3018 let mut markdown = Vec::new();
3019
3020 let summary = self.summary().or_default();
3021 writeln!(markdown, "# {summary}\n")?;
3022
3023 for message in self.messages() {
3024 writeln!(
3025 markdown,
3026 "## {role}\n",
3027 role = match message.role {
3028 Role::User => "User",
3029 Role::Assistant => "Agent",
3030 Role::System => "System",
3031 }
3032 )?;
3033
3034 if !message.loaded_context.text.is_empty() {
3035 writeln!(markdown, "{}", message.loaded_context.text)?;
3036 }
3037
3038 if !message.loaded_context.images.is_empty() {
3039 writeln!(
3040 markdown,
3041 "\n{} images attached as context.\n",
3042 message.loaded_context.images.len()
3043 )?;
3044 }
3045
3046 for segment in &message.segments {
3047 match segment {
3048 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
3049 MessageSegment::Thinking { text, .. } => {
3050 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
3051 }
3052 MessageSegment::RedactedThinking(_) => {}
3053 }
3054 }
3055
3056 for tool_use in self.tool_uses_for_message(message.id, cx) {
3057 writeln!(
3058 markdown,
3059 "**Use Tool: {} ({})**",
3060 tool_use.name, tool_use.id
3061 )?;
3062 writeln!(markdown, "```json")?;
3063 writeln!(
3064 markdown,
3065 "{}",
3066 serde_json::to_string_pretty(&tool_use.input)?
3067 )?;
3068 writeln!(markdown, "```")?;
3069 }
3070
3071 for tool_result in self.tool_results_for_message(message.id) {
3072 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
3073 if tool_result.is_error {
3074 write!(markdown, " (Error)")?;
3075 }
3076
3077 writeln!(markdown, "**\n")?;
3078 match &tool_result.content {
3079 LanguageModelToolResultContent::Text(text) => {
3080 writeln!(markdown, "{text}")?;
3081 }
3082 LanguageModelToolResultContent::Image(image) => {
                            writeln!(markdown, "![Image]({})", image.source)?;
3084 }
3085 }
3086
3087 if let Some(output) = tool_result.output.as_ref() {
3088 writeln!(
3089 markdown,
3090 "\n\nDebug Output:\n\n```json\n{}\n```\n",
3091 serde_json::to_string_pretty(output)?
3092 )?;
3093 }
3094 }
3095 }
3096
3097 Ok(String::from_utf8_lossy(&markdown).to_string())
3098 }
3099
3100 pub fn keep_edits_in_range(
3101 &mut self,
3102 buffer: Entity<language::Buffer>,
3103 buffer_range: Range<language::Anchor>,
3104 cx: &mut Context<Self>,
3105 ) {
3106 self.action_log.update(cx, |action_log, cx| {
3107 action_log.keep_edits_in_range(buffer, buffer_range, cx)
3108 });
3109 }
3110
3111 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
3112 self.action_log
3113 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
3114 }
3115
3116 pub fn reject_edits_in_ranges(
3117 &mut self,
3118 buffer: Entity<language::Buffer>,
3119 buffer_ranges: Vec<Range<language::Anchor>>,
3120 cx: &mut Context<Self>,
3121 ) -> Task<Result<()>> {
3122 self.action_log.update(cx, |action_log, cx| {
3123 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
3124 })
3125 }
3126
3127 pub fn action_log(&self) -> &Entity<ActionLog> {
3128 &self.action_log
3129 }
3130
3131 pub fn project(&self) -> &Entity<Project> {
3132 &self.project
3133 }
3134
3135 pub fn cumulative_token_usage(&self) -> TokenUsage {
3136 self.cumulative_token_usage
3137 }
3138
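    /// Returns the token usage recorded just before `message_id`, relative to the
    /// model's maximum context window.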
3139 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3140 let Some(model) = self.configured_model.as_ref() else {
3141 return TotalTokenUsage::default();
3142 };
3143
3144 let max = model
3145 .model
3146 .max_token_count_for_mode(self.completion_mode().into());
3147
3148 let index = self
3149 .messages
3150 .iter()
3151 .position(|msg| msg.id == message_id)
3152 .unwrap_or(0);
3153
3154 if index == 0 {
3155 return TotalTokenUsage { total: 0, max };
3156 }
3157
3158 let token_usage = &self
3159 .request_token_usage
3160 .get(index - 1)
3161 .cloned()
3162 .unwrap_or_default();
3163
3164 TotalTokenUsage {
3165 total: token_usage.total_tokens(),
3166 max,
3167 }
3168 }
3169
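    /// Returns the thread's total token usage relative to the model's context
    /// window, preferring the exact count captured when the window was exceeded.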
3170 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3171 let model = self.configured_model.as_ref()?;
3172
3173 let max = model
3174 .model
3175 .max_token_count_for_mode(self.completion_mode().into());
3176
3177 if let Some(exceeded_error) = &self.exceeded_window_error
3178 && model.model.id() == exceeded_error.model_id
3179 {
3180 return Some(TotalTokenUsage {
3181 total: exceeded_error.token_count,
3182 max,
3183 });
3184 }
3185
3186 let total = self
3187 .token_usage_at_last_message()
3188 .unwrap_or_default()
3189 .total_tokens();
3190
3191 Some(TotalTokenUsage { total, max })
3192 }
3193
3194 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
3195 self.request_token_usage
3196 .get(self.messages.len().saturating_sub(1))
3197 .or_else(|| self.request_token_usage.last())
3198 .cloned()
3199 }
3200
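    /// Records `token_usage` for the latest message, padding any missing entries
    /// so `request_token_usage` stays aligned with `messages`.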
3201 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3202 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3203 self.request_token_usage
3204 .resize(self.messages.len(), placeholder);
3205
3206 if let Some(last) = self.request_token_usage.last_mut() {
3207 *last = token_usage;
3208 }
3209 }
3210
3211 fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3212 self.project
3213 .read(cx)
3214 .user_store()
3215 .update(cx, |user_store, cx| {
3216 user_store.update_model_request_usage(
3217 ModelRequestUsage(RequestUsage {
3218 amount: amount as i32,
3219 limit,
3220 }),
3221 cx,
3222 )
3223 });
3224 }
3225
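    /// Rejects a pending tool use, recording a permission-denied error as its result.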
3226 pub fn deny_tool_use(
3227 &mut self,
3228 tool_use_id: LanguageModelToolUseId,
3229 tool_name: Arc<str>,
3230 window: Option<AnyWindowHandle>,
3231 cx: &mut Context<Self>,
3232 ) {
3233 let err = Err(anyhow::anyhow!(
3234 "Permission to run tool action denied by user"
3235 ));
3236
3237 self.tool_use.insert_tool_output(
3238 tool_use_id.clone(),
3239 tool_name,
3240 err,
3241 self.configured_model.as_ref(),
3242 self.completion_mode,
3243 );
3244 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3245 }
3246}
3247
3248#[derive(Debug, Clone, Error)]
3249pub enum ThreadError {
3250 #[error("Payment required")]
3251 PaymentRequired,
3252 #[error("Model request limit reached")]
3253 ModelRequestLimitReached { plan: Plan },
3254 #[error("Message {header}: {message}")]
3255 Message {
3256 header: SharedString,
3257 message: SharedString,
3258 },
3259 #[error("Retryable error: {message}")]
3260 RetryableError {
3261 message: SharedString,
3262 can_enable_burn_mode: bool,
3263 },
3264}
3265
3266#[derive(Debug, Clone)]
3267pub enum ThreadEvent {
3268 ShowError(ThreadError),
3269 StreamedCompletion,
3270 ReceivedTextChunk,
3271 NewRequest,
3272 StreamedAssistantText(MessageId, String),
3273 StreamedAssistantThinking(MessageId, String),
3274 StreamedToolUse {
3275 tool_use_id: LanguageModelToolUseId,
3276 ui_text: Arc<str>,
3277 input: serde_json::Value,
3278 },
3279 MissingToolUse {
3280 tool_use_id: LanguageModelToolUseId,
3281 ui_text: Arc<str>,
3282 },
3283 InvalidToolInput {
3284 tool_use_id: LanguageModelToolUseId,
3285 ui_text: Arc<str>,
3286 invalid_input_json: Arc<str>,
3287 },
3288 Stopped(Result<StopReason, Arc<anyhow::Error>>),
3289 MessageAdded(MessageId),
3290 MessageEdited(MessageId),
3291 MessageDeleted(MessageId),
3292 SummaryGenerated,
3293 SummaryChanged,
3294 UsePendingTools {
3295 tool_uses: Vec<PendingToolUse>,
3296 },
3297 ToolFinished {
3298 #[allow(unused)]
3299 tool_use_id: LanguageModelToolUseId,
3300 /// The pending tool use that corresponds to this tool.
3301 pending_tool_use: Option<PendingToolUse>,
3302 },
3303 CheckpointChanged,
3304 ToolConfirmationNeeded,
3305 ToolUseLimitReached,
3306 CancelEditing,
3307 CompletionCanceled,
3308 ProfileChanged,
3309}
3310
3311impl EventEmitter<ThreadEvent> for Thread {}
3312
3313struct PendingCompletion {
3314 id: usize,
3315 queue_state: QueueState,
3316 _task: Task<()>,
3317}
3318
3319#[cfg(test)]
3320mod tests {
3321 use super::*;
3322 use crate::{
3323 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3324 };
3325
    use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3329 use assistant_tool::ToolRegistry;
3330 use assistant_tools;
3331 use futures::StreamExt;
3332 use futures::future::BoxFuture;
3333 use futures::stream::BoxStream;
3334 use gpui::TestAppContext;
3335 use http_client;
3336 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3337 use language_model::{
3338 LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3339 LanguageModelProviderName, LanguageModelToolChoice,
3340 };
3341 use parking_lot::Mutex;
3342 use project::{FakeFs, Project};
3343 use prompt_store::PromptBuilder;
3344 use serde_json::json;
3345 use settings::{Settings, SettingsStore};
3346 use std::sync::Arc;
3347 use std::time::Duration;
3348 use theme::ThemeSettings;
3349 use util::path;
3350 use workspace::Workspace;

    // Test-specific constants
    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3351
3352 #[gpui::test]
3353 async fn test_message_with_context(cx: &mut TestAppContext) {
3354 init_test_settings(cx);
3355
3356 let project = create_test_project(
3357 cx,
3358 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3359 )
3360 .await;
3361
3362 let (_workspace, _thread_store, thread, context_store, model) =
3363 setup_test_environment(cx, project.clone()).await;
3364
3365 add_file_to_context(&project, &context_store, "test/code.rs", cx)
3366 .await
3367 .unwrap();
3368
3369 let context =
3370 context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
3371 let loaded_context = cx
3372 .update(|cx| load_context(vec![context], &project, &None, cx))
3373 .await;
3374
3375 // Insert user message with context
3376 let message_id = thread.update(cx, |thread, cx| {
3377 thread.insert_user_message(
3378 "Please explain this code",
3379 loaded_context,
3380 None,
3381 Vec::new(),
3382 cx,
3383 )
3384 });
3385
3386 // Check content and context in message object
3387 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3388
3389 // Use different path format strings based on platform for the test
3390 #[cfg(windows)]
3391 let path_part = r"test\code.rs";
3392 #[cfg(not(windows))]
3393 let path_part = "test/code.rs";
3394
3395 let expected_context = format!(
3396 r#"
3397<context>
3398The following items were attached by the user. They are up-to-date and don't need to be re-read.
3399
3400<files>
3401```rs {path_part}
3402fn main() {{
3403 println!("Hello, world!");
3404}}
3405```
3406</files>
3407</context>
3408"#
3409 );
3410
3411 assert_eq!(message.role, Role::User);
3412 assert_eq!(message.segments.len(), 1);
3413 assert_eq!(
3414 message.segments[0],
3415 MessageSegment::Text("Please explain this code".to_string())
3416 );
3417 assert_eq!(message.loaded_context.text, expected_context);
3418
3419 // Check message in request
3420 let request = thread.update(cx, |thread, cx| {
3421 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3422 });
3423
3424 assert_eq!(request.messages.len(), 2);
3425 let expected_full_message = format!("{}Please explain this code", expected_context);
3426 assert_eq!(request.messages[1].string_contents(), expected_full_message);
3427 }
3428
3429 #[gpui::test]
3430 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
3431 init_test_settings(cx);
3432
3433 let project = create_test_project(
3434 cx,
3435 json!({
3436 "file1.rs": "fn function1() {}\n",
3437 "file2.rs": "fn function2() {}\n",
3438 "file3.rs": "fn function3() {}\n",
3439 "file4.rs": "fn function4() {}\n",
3440 }),
3441 )
3442 .await;
3443
3444 let (_, _thread_store, thread, context_store, model) =
3445 setup_test_environment(cx, project.clone()).await;
3446
3447 // First message with context 1
3448 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
3449 .await
3450 .unwrap();
3451 let new_contexts = context_store.update(cx, |store, cx| {
3452 store.new_context_for_thread(thread.read(cx), None)
3453 });
3454 assert_eq!(new_contexts.len(), 1);
3455 let loaded_context = cx
3456 .update(|cx| load_context(new_contexts, &project, &None, cx))
3457 .await;
3458 let message1_id = thread.update(cx, |thread, cx| {
3459 thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
3460 });
3461
3462 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
3463 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
3464 .await
3465 .unwrap();
3466 let new_contexts = context_store.update(cx, |store, cx| {
3467 store.new_context_for_thread(thread.read(cx), None)
3468 });
3469 assert_eq!(new_contexts.len(), 1);
3470 let loaded_context = cx
3471 .update(|cx| load_context(new_contexts, &project, &None, cx))
3472 .await;
3473 let message2_id = thread.update(cx, |thread, cx| {
3474 thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
3475 });
3476
        // Third message with all three contexts (contexts 1 and 2 should be skipped)
3479 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
3480 .await
3481 .unwrap();
3482 let new_contexts = context_store.update(cx, |store, cx| {
3483 store.new_context_for_thread(thread.read(cx), None)
3484 });
3485 assert_eq!(new_contexts.len(), 1);
3486 let loaded_context = cx
3487 .update(|cx| load_context(new_contexts, &project, &None, cx))
3488 .await;
3489 let message3_id = thread.update(cx, |thread, cx| {
3490 thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
3491 });
3492
3493 // Check what contexts are included in each message
3494 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
3495 (
3496 thread.message(message1_id).unwrap().clone(),
3497 thread.message(message2_id).unwrap().clone(),
3498 thread.message(message3_id).unwrap().clone(),
3499 )
3500 });
3501
3502 // First message should include context 1
3503 assert!(message1.loaded_context.text.contains("file1.rs"));
3504
3505 // Second message should include only context 2 (not 1)
3506 assert!(!message2.loaded_context.text.contains("file1.rs"));
3507 assert!(message2.loaded_context.text.contains("file2.rs"));
3508
3509 // Third message should include only context 3 (not 1 or 2)
3510 assert!(!message3.loaded_context.text.contains("file1.rs"));
3511 assert!(!message3.loaded_context.text.contains("file2.rs"));
3512 assert!(message3.loaded_context.text.contains("file3.rs"));
3513
3514 // Check entire request to make sure all contexts are properly included
3515 let request = thread.update(cx, |thread, cx| {
3516 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3517 });
3518
3519 // The request should contain all 3 messages
3520 assert_eq!(request.messages.len(), 4);
3521
3522 // Check that the contexts are properly formatted in each message
3523 assert!(request.messages[1].string_contents().contains("file1.rs"));
3524 assert!(!request.messages[1].string_contents().contains("file2.rs"));
3525 assert!(!request.messages[1].string_contents().contains("file3.rs"));
3526
3527 assert!(!request.messages[2].string_contents().contains("file1.rs"));
3528 assert!(request.messages[2].string_contents().contains("file2.rs"));
3529 assert!(!request.messages[2].string_contents().contains("file3.rs"));
3530
3531 assert!(!request.messages[3].string_contents().contains("file1.rs"));
3532 assert!(!request.messages[3].string_contents().contains("file2.rs"));
3533 assert!(request.messages[3].string_contents().contains("file3.rs"));
3534
3535 add_file_to_context(&project, &context_store, "test/file4.rs", cx)
3536 .await
3537 .unwrap();
3538 let new_contexts = context_store.update(cx, |store, cx| {
3539 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3540 });
3541 assert_eq!(new_contexts.len(), 3);
3542 let loaded_context = cx
3543 .update(|cx| load_context(new_contexts, &project, &None, cx))
3544 .await
3545 .loaded_context;
3546
3547 assert!(!loaded_context.text.contains("file1.rs"));
3548 assert!(loaded_context.text.contains("file2.rs"));
3549 assert!(loaded_context.text.contains("file3.rs"));
3550 assert!(loaded_context.text.contains("file4.rs"));
3551
3552 let new_contexts = context_store.update(cx, |store, cx| {
3553 // Remove file4.rs
3554 store.remove_context(&loaded_context.contexts[2].handle(), cx);
3555 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3556 });
3557 assert_eq!(new_contexts.len(), 2);
3558 let loaded_context = cx
3559 .update(|cx| load_context(new_contexts, &project, &None, cx))
3560 .await
3561 .loaded_context;
3562
3563 assert!(!loaded_context.text.contains("file1.rs"));
3564 assert!(loaded_context.text.contains("file2.rs"));
3565 assert!(loaded_context.text.contains("file3.rs"));
3566 assert!(!loaded_context.text.contains("file4.rs"));
3567
3568 let new_contexts = context_store.update(cx, |store, cx| {
3569 // Remove file3.rs
3570 store.remove_context(&loaded_context.contexts[1].handle(), cx);
3571 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3572 });
3573 assert_eq!(new_contexts.len(), 1);
3574 let loaded_context = cx
3575 .update(|cx| load_context(new_contexts, &project, &None, cx))
3576 .await
3577 .loaded_context;
3578
3579 assert!(!loaded_context.text.contains("file1.rs"));
3580 assert!(loaded_context.text.contains("file2.rs"));
3581 assert!(!loaded_context.text.contains("file3.rs"));
3582 assert!(!loaded_context.text.contains("file4.rs"));
3583 }
3584
3585 #[gpui::test]
3586 async fn test_message_without_files(cx: &mut TestAppContext) {
3587 init_test_settings(cx);
3588
3589 let project = create_test_project(
3590 cx,
3591 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3592 )
3593 .await;
3594
3595 let (_, _thread_store, thread, _context_store, model) =
3596 setup_test_environment(cx, project.clone()).await;
3597
3598 // Insert user message without any context (empty context vector)
3599 let message_id = thread.update(cx, |thread, cx| {
3600 thread.insert_user_message(
3601 "What is the best way to learn Rust?",
3602 ContextLoadResult::default(),
3603 None,
3604 Vec::new(),
3605 cx,
3606 )
3607 });
3608
3609 // Check content and context in message object
3610 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3611
3612 // Context should be empty when no files are included
3613 assert_eq!(message.role, Role::User);
3614 assert_eq!(message.segments.len(), 1);
3615 assert_eq!(
3616 message.segments[0],
3617 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3618 );
3619 assert_eq!(message.loaded_context.text, "");
3620
3621 // Check message in request
3622 let request = thread.update(cx, |thread, cx| {
3623 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3624 });
3625
3626 assert_eq!(request.messages.len(), 2);
3627 assert_eq!(
3628 request.messages[1].string_contents(),
3629 "What is the best way to learn Rust?"
3630 );
3631
3632 // Add second message, also without context
3633 let message2_id = thread.update(cx, |thread, cx| {
3634 thread.insert_user_message(
3635 "Are there any good books?",
3636 ContextLoadResult::default(),
3637 None,
3638 Vec::new(),
3639 cx,
3640 )
3641 });
3642
3643 let message2 =
3644 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3645 assert_eq!(message2.loaded_context.text, "");
3646
3647 // Check that both messages appear in the request
3648 let request = thread.update(cx, |thread, cx| {
3649 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3650 });
3651
3652 assert_eq!(request.messages.len(), 3);
3653 assert_eq!(
3654 request.messages[1].string_contents(),
3655 "What is the best way to learn Rust?"
3656 );
3657 assert_eq!(
3658 request.messages[2].string_contents(),
3659 "Are there any good books?"
3660 );
3661 }
3662
3663 #[gpui::test]
3664 #[ignore] // turn this test on when project_notifications tool is re-enabled
3665 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
3666 init_test_settings(cx);
3667
3668 let project = create_test_project(
3669 cx,
3670 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3671 )
3672 .await;
3673
3674 let (_workspace, _thread_store, thread, context_store, model) =
3675 setup_test_environment(cx, project.clone()).await;
3676
3677 // Add a buffer to the context. This will be a tracked buffer
3678 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
3679 .await
3680 .unwrap();
3681
3682 let context = context_store
3683 .read_with(cx, |store, _| store.context().next().cloned())
3684 .unwrap();
3685 let loaded_context = cx
3686 .update(|cx| load_context(vec![context], &project, &None, cx))
3687 .await;
3688
3689 // Insert user message and assistant response
3690 thread.update(cx, |thread, cx| {
3691 thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
3692 thread.insert_assistant_message(
3693 vec![MessageSegment::Text("This code prints 42.".into())],
3694 cx,
3695 );
3696 });
3697 cx.run_until_parked();
3698
3699 // We shouldn't have a stale buffer notification yet
3700 let notifications = thread.read_with(cx, |thread, _| {
3701 find_tool_uses(thread, "project_notifications")
3702 });
3703 assert!(
3704 notifications.is_empty(),
3705 "Should not have stale buffer notification before buffer is modified"
3706 );
3707
3708 // Modify the buffer
3709 buffer.update(cx, |buffer, cx| {
3710 buffer.edit(
3711 [(1..1, "\n println!(\"Added a new line\");\n")],
3712 None,
3713 cx,
3714 );
3715 });
3716
3717 // Insert another user message
3718 thread.update(cx, |thread, cx| {
3719 thread.insert_user_message(
3720 "What does the code do now?",
3721 ContextLoadResult::default(),
3722 None,
3723 Vec::new(),
3724 cx,
3725 )
3726 });
3727 cx.run_until_parked();
3728
3729 // Check for the stale buffer warning
3730 thread.update(cx, |thread, cx| {
3731 thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
3732 });
3733 cx.run_until_parked();
3734
3735 let notifications = thread.read_with(cx, |thread, _cx| {
3736 find_tool_uses(thread, "project_notifications")
3737 });
3738
3739 let [notification] = notifications.as_slice() else {
3740 panic!("Should have a `project_notifications` tool use");
3741 };
3742
3743 let Some(notification_content) = notification.content.to_str() else {
3744 panic!("`project_notifications` should return text");
3745 };
3746
3747 assert!(notification_content.contains("These files have changed since the last read:"));
3748 assert!(notification_content.contains("code.rs"));
3749
3750 // Insert another user message and flush notifications again
3751 thread.update(cx, |thread, cx| {
3752 thread.insert_user_message(
3753 "Can you tell me more?",
3754 ContextLoadResult::default(),
3755 None,
3756 Vec::new(),
3757 cx,
3758 )
3759 });
3760
3761 thread.update(cx, |thread, cx| {
3762 thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
3763 });
3764 cx.run_until_parked();
3765
3766 // There should be no new notifications (we already flushed one)
3767 let notifications = thread.read_with(cx, |thread, _cx| {
3768 find_tool_uses(thread, "project_notifications")
3769 });
3770
3771 assert_eq!(
3772 notifications.len(),
3773 1,
3774 "Should still have only one notification after second flush - no duplicates"
3775 );
3776 }
3777
3778 fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3779 thread
3780 .messages()
3781 .flat_map(|message| {
3782 thread
3783 .tool_results_for_message(message.id)
3784 .into_iter()
3785 .filter(|result| result.tool_name == tool_name.into())
3786 .cloned()
3787 .collect::<Vec<_>>()
3788 })
3789 .collect()
3790 }
3791
3792 #[gpui::test]
3793 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3794 init_test_settings(cx);
3795
3796 let project = create_test_project(
3797 cx,
3798 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3799 )
3800 .await;
3801
3802 let (_workspace, thread_store, thread, _context_store, _model) =
3803 setup_test_environment(cx, project.clone()).await;
3804
3805 // Check that we are starting with the default profile
3806 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3807 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3808 assert_eq!(
3809 profile,
3810 AgentProfile::new(AgentProfileId::default(), tool_set)
3811 );
3812 }
3813
3814 #[gpui::test]
3815 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3816 init_test_settings(cx);
3817
3818 let project = create_test_project(
3819 cx,
3820 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3821 )
3822 .await;
3823
3824 let (_workspace, thread_store, thread, _context_store, _model) =
3825 setup_test_environment(cx, project.clone()).await;
3826
3827 // Profile gets serialized with default values
3828 let serialized = thread
3829 .update(cx, |thread, cx| thread.serialize(cx))
3830 .await
3831 .unwrap();
3832
3833 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3834
3835 let deserialized = cx.update(|cx| {
3836 thread.update(cx, |thread, cx| {
3837 Thread::deserialize(
3838 thread.id.clone(),
3839 serialized,
3840 thread.project.clone(),
3841 thread.tools.clone(),
3842 thread.prompt_builder.clone(),
3843 thread.project_context.clone(),
3844 None,
3845 cx,
3846 )
3847 })
3848 });
3849 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3850
3851 assert_eq!(
3852 deserialized.profile,
3853 AgentProfile::new(AgentProfileId::default(), tool_set)
3854 );
3855 }
3856
3857 #[gpui::test]
3858 async fn test_temperature_setting(cx: &mut TestAppContext) {
3859 init_test_settings(cx);
3860
3861 let project = create_test_project(
3862 cx,
3863 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3864 )
3865 .await;
3866
3867 let (_workspace, _thread_store, thread, _context_store, model) =
3868 setup_test_environment(cx, project.clone()).await;
3869
3870 // Both model and provider
3871 cx.update(|cx| {
3872 AgentSettings::override_global(
3873 AgentSettings {
3874 model_parameters: vec![LanguageModelParameters {
3875 provider: Some(model.provider_id().0.to_string().into()),
3876 model: Some(model.id().0.clone()),
3877 temperature: Some(0.66),
3878 }],
3879 ..AgentSettings::get_global(cx).clone()
3880 },
3881 cx,
3882 );
3883 });
3884
3885 let request = thread.update(cx, |thread, cx| {
3886 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3887 });
3888 assert_eq!(request.temperature, Some(0.66));
3889
3890 // Only model
3891 cx.update(|cx| {
3892 AgentSettings::override_global(
3893 AgentSettings {
3894 model_parameters: vec![LanguageModelParameters {
3895 provider: None,
3896 model: Some(model.id().0.clone()),
3897 temperature: Some(0.66),
3898 }],
3899 ..AgentSettings::get_global(cx).clone()
3900 },
3901 cx,
3902 );
3903 });
3904
3905 let request = thread.update(cx, |thread, cx| {
3906 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3907 });
3908 assert_eq!(request.temperature, Some(0.66));
3909
        // Override matching the provider only
3911 cx.update(|cx| {
3912 AgentSettings::override_global(
3913 AgentSettings {
3914 model_parameters: vec![LanguageModelParameters {
3915 provider: Some(model.provider_id().0.to_string().into()),
3916 model: None,
3917 temperature: Some(0.66),
3918 }],
3919 ..AgentSettings::get_global(cx).clone()
3920 },
3921 cx,
3922 );
3923 });
3924
3925 let request = thread.update(cx, |thread, cx| {
3926 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3927 });
3928 assert_eq!(request.temperature, Some(0.66));
3929
        // Same model name but a different provider: the override should not apply
3931 cx.update(|cx| {
3932 AgentSettings::override_global(
3933 AgentSettings {
3934 model_parameters: vec![LanguageModelParameters {
3935 provider: Some("anthropic".into()),
3936 model: Some(model.id().0.clone()),
3937 temperature: Some(0.66),
3938 }],
3939 ..AgentSettings::get_global(cx).clone()
3940 },
3941 cx,
3942 );
3943 });
3944
3945 let request = thread.update(cx, |thread, cx| {
3946 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3947 });
3948 assert_eq!(request.temperature, None);
3949 }
3950
3951 #[gpui::test]
3952 async fn test_thread_summary(cx: &mut TestAppContext) {
3953 init_test_settings(cx);
3954
3955 let project = create_test_project(cx, json!({})).await;
3956
3957 let (_, _thread_store, thread, _context_store, model) =
3958 setup_test_environment(cx, project.clone()).await;
3959
3960 // Initial state should be pending
3961 thread.read_with(cx, |thread, _| {
3962 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3963 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3964 });
3965
3966 // Manually setting the summary should not be allowed in this state
3967 thread.update(cx, |thread, cx| {
3968 thread.set_summary("This should not work", cx);
3969 });
3970
3971 thread.read_with(cx, |thread, _| {
3972 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3973 });
3974
3975 // Send a message
3976 thread.update(cx, |thread, cx| {
3977 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3978 thread.send_to_model(
3979 model.clone(),
3980 CompletionIntent::ThreadSummarization,
3981 None,
3982 cx,
3983 );
3984 });
3985
3986 let fake_model = model.as_fake();
3987 simulate_successful_response(fake_model, cx);
3988
3989 // Should start generating summary when there are >= 2 messages
3990 thread.read_with(cx, |thread, _| {
3991 assert_eq!(*thread.summary(), ThreadSummary::Generating);
3992 });
3993
3994 // Should not be able to set the summary while generating
3995 thread.update(cx, |thread, cx| {
3996 thread.set_summary("This should not work either", cx);
3997 });
3998
3999 thread.read_with(cx, |thread, _| {
4000 assert!(matches!(thread.summary(), ThreadSummary::Generating));
4001 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
4002 });
4003
4004 cx.run_until_parked();
4005 fake_model.send_last_completion_stream_text_chunk("Brief");
4006 fake_model.send_last_completion_stream_text_chunk(" Introduction");
4007 fake_model.end_last_completion_stream();
4008 cx.run_until_parked();
4009
4010 // Summary should be set
4011 thread.read_with(cx, |thread, _| {
4012 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4013 assert_eq!(thread.summary().or_default(), "Brief Introduction");
4014 });
4015
4016 // Now we should be able to set a summary
4017 thread.update(cx, |thread, cx| {
4018 thread.set_summary("Brief Intro", cx);
4019 });
4020
4021 thread.read_with(cx, |thread, _| {
4022 assert_eq!(thread.summary().or_default(), "Brief Intro");
4023 });
4024
        // Setting an empty summary should fall back to ThreadSummary::DEFAULT
4026 thread.update(cx, |thread, cx| {
4027 thread.set_summary("", cx);
4028 });
4029
4030 thread.read_with(cx, |thread, _| {
4031 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4032 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
4033 });
4034 }
4035
4036 #[gpui::test]
4037 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
4038 init_test_settings(cx);
4039
4040 let project = create_test_project(cx, json!({})).await;
4041
4042 let (_, _thread_store, thread, _context_store, model) =
4043 setup_test_environment(cx, project.clone()).await;
4044
4045 test_summarize_error(&model, &thread, cx);
4046
4047 // Now we should be able to set a summary
4048 thread.update(cx, |thread, cx| {
4049 thread.set_summary("Brief Intro", cx);
4050 });
4051
4052 thread.read_with(cx, |thread, _| {
4053 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4054 assert_eq!(thread.summary().or_default(), "Brief Intro");
4055 });
4056 }
4057
4058 #[gpui::test]
4059 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
4060 init_test_settings(cx);
4061
4062 let project = create_test_project(cx, json!({})).await;
4063
4064 let (_, _thread_store, thread, _context_store, model) =
4065 setup_test_environment(cx, project.clone()).await;
4066
4067 test_summarize_error(&model, &thread, cx);
4068
4069 // Sending another message should not trigger another summarize request
4070 thread.update(cx, |thread, cx| {
4071 thread.insert_user_message(
4072 "How are you?",
4073 ContextLoadResult::default(),
4074 None,
4075 vec![],
4076 cx,
4077 );
4078 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4079 });
4080
4081 let fake_model = model.as_fake();
4082 simulate_successful_response(fake_model, cx);
4083
4084 thread.read_with(cx, |thread, _| {
4085 // State is still Error, not Generating
4086 assert!(matches!(thread.summary(), ThreadSummary::Error));
4087 });
4088
4089 // But the summarize request can be invoked manually
4090 thread.update(cx, |thread, cx| {
4091 thread.summarize(cx);
4092 });
4093
4094 thread.read_with(cx, |thread, _| {
4095 assert!(matches!(thread.summary(), ThreadSummary::Generating));
4096 });
4097
4098 cx.run_until_parked();
4099 fake_model.send_last_completion_stream_text_chunk("A successful summary");
4100 fake_model.end_last_completion_stream();
4101 cx.run_until_parked();
4102
4103 thread.read_with(cx, |thread, _| {
4104 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4105 assert_eq!(thread.summary().or_default(), "A successful summary");
4106 });
4107 }
4108
    // Helpers for simulating provider failures: the kinds of error to inject,
    // plus a model wrapper that always returns the configured one.
4110 enum TestError {
4111 Overloaded,
4112 InternalServerError,
4113 }
4114
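    // Delegates all metadata queries to an inner FakeLanguageModel, but fails
    // every stream_completion call with the configured error.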
4115 struct ErrorInjector {
4116 inner: Arc<FakeLanguageModel>,
4117 error_type: TestError,
4118 }
4119
4120 impl ErrorInjector {
4121 fn new(error_type: TestError) -> Self {
4122 Self {
4123 inner: Arc::new(FakeLanguageModel::default()),
4124 error_type,
4125 }
4126 }
4127 }
4128
4129 impl LanguageModel for ErrorInjector {
4130 fn id(&self) -> LanguageModelId {
4131 self.inner.id()
4132 }
4133
4134 fn name(&self) -> LanguageModelName {
4135 self.inner.name()
4136 }
4137
4138 fn provider_id(&self) -> LanguageModelProviderId {
4139 self.inner.provider_id()
4140 }
4141
4142 fn provider_name(&self) -> LanguageModelProviderName {
4143 self.inner.provider_name()
4144 }
4145
4146 fn supports_tools(&self) -> bool {
4147 self.inner.supports_tools()
4148 }
4149
4150 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4151 self.inner.supports_tool_choice(choice)
4152 }
4153
4154 fn supports_images(&self) -> bool {
4155 self.inner.supports_images()
4156 }
4157
4158 fn telemetry_id(&self) -> String {
4159 self.inner.telemetry_id()
4160 }
4161
4162 fn max_token_count(&self) -> u64 {
4163 self.inner.max_token_count()
4164 }
4165
4166 fn count_tokens(
4167 &self,
4168 request: LanguageModelRequest,
4169 cx: &App,
4170 ) -> BoxFuture<'static, Result<u64>> {
4171 self.inner.count_tokens(request, cx)
4172 }
4173
4174 fn stream_completion(
4175 &self,
4176 _request: LanguageModelRequest,
4177 _cx: &AsyncApp,
4178 ) -> BoxFuture<
4179 'static,
4180 Result<
4181 BoxStream<
4182 'static,
4183 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4184 >,
4185 LanguageModelCompletionError,
4186 >,
4187 > {
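            // Build the configured error eagerly, then yield it as the only item of the stream.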
4188 let error = match self.error_type {
4189 TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
4190 provider: self.provider_name(),
4191 retry_after: None,
4192 },
4193 TestError::InternalServerError => {
4194 LanguageModelCompletionError::ApiInternalServerError {
4195 provider: self.provider_name(),
4196 message: "I'm a teapot orbiting the sun".to_string(),
4197 }
4198 }
4199 };
4200 async move {
4201 let stream = futures::stream::once(async move { Err(error) });
4202 Ok(stream.boxed())
4203 }
4204 .boxed()
4205 }
4206
4207 fn as_fake(&self) -> &FakeLanguageModel {
4208 &self.inner
4209 }
4210 }
4211
4212 #[gpui::test]
4213 async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4214 init_test_settings(cx);
4215
4216 let project = create_test_project(cx, json!({})).await;
4217 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4218
4219 // Enable Burn Mode to allow retries
4220 thread.update(cx, |thread, _| {
4221 thread.set_completion_mode(CompletionMode::Burn);
4222 });
4223
4224 // Create model that returns overloaded error
4225 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4226
4227 // Insert a user message
4228 thread.update(cx, |thread, cx| {
4229 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4230 });
4231
4232 // Start completion
4233 thread.update(cx, |thread, cx| {
4234 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4235 });
4236
4237 cx.run_until_parked();
4238
4239 thread.read_with(cx, |thread, _| {
4240 assert!(thread.retry_state.is_some(), "Should have retry state");
4241 let retry_state = thread.retry_state.as_ref().unwrap();
4242 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4243 assert_eq!(
4244 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4245 "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
4246 );
4247 });
4248
4249 // Check that a retry message was added
4250 thread.read_with(cx, |thread, _| {
4251 let mut messages = thread.messages();
4252 assert!(
4253 messages.any(|msg| {
4254 msg.role == Role::System
4255 && msg.ui_only
4256 && msg.segments.iter().any(|seg| {
4257 if let MessageSegment::Text(text) = seg {
4258 text.contains("overloaded")
4259 && text
4260 .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4261 } else {
4262 false
4263 }
4264 })
4265 }),
4266 "Should have added a system retry message"
4267 );
4268 });
4269
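        // Count the scheduled-retry notification messages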
4270 let retry_count = thread.update(cx, |thread, _| {
4271 thread
4272 .messages
4273 .iter()
4274 .filter(|m| {
4275 m.ui_only
4276 && m.segments.iter().any(|s| {
4277 if let MessageSegment::Text(text) = s {
4278 text.contains("Retrying") && text.contains("seconds")
4279 } else {
4280 false
4281 }
4282 })
4283 })
4284 .count()
4285 });
4286
4287 assert_eq!(retry_count, 1, "Should have one retry message");
4288 }
4289
4290 #[gpui::test]
4291 async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4292 init_test_settings(cx);
4293
4294 let project = create_test_project(cx, json!({})).await;
4295 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4296
4297 // Enable Burn Mode to allow retries
4298 thread.update(cx, |thread, _| {
4299 thread.set_completion_mode(CompletionMode::Burn);
4300 });
4301
4302 // Create model that returns internal server error
4303 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4304
4305 // Insert a user message
4306 thread.update(cx, |thread, cx| {
4307 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4308 });
4309
4310 // Start completion
4311 thread.update(cx, |thread, cx| {
4312 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4313 });
4314
4315 cx.run_until_parked();
4316
4317 // Check retry state on thread
4318 thread.read_with(cx, |thread, _| {
4319 assert!(thread.retry_state.is_some(), "Should have retry state");
4320 let retry_state = thread.retry_state.as_ref().unwrap();
4321 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4322 assert_eq!(
4323 retry_state.max_attempts, 3,
4324 "Should have correct max attempts"
4325 );
4326 });
4327
4328 // Check that a retry message was added with provider name
4329 thread.read_with(cx, |thread, _| {
4330 let mut messages = thread.messages();
4331 assert!(
4332 messages.any(|msg| {
4333 msg.role == Role::System
4334 && msg.ui_only
4335 && msg.segments.iter().any(|seg| {
4336 if let MessageSegment::Text(text) = seg {
4337 text.contains("internal")
4338 && text.contains("Fake")
4339 && text.contains("Retrying")
4340 && text.contains("attempt 1 of 3")
4341 && text.contains("seconds")
4342 } else {
4343 false
4344 }
4345 })
4346 }),
4347 "Should have added a system retry message with provider name"
4348 );
4349 });
4350
4351 // Count retry messages
4352 let retry_count = thread.update(cx, |thread, _| {
4353 thread
4354 .messages
4355 .iter()
4356 .filter(|m| {
4357 m.ui_only
4358 && m.segments.iter().any(|s| {
4359 if let MessageSegment::Text(text) = s {
4360 text.contains("Retrying") && text.contains("seconds")
4361 } else {
4362 false
4363 }
4364 })
4365 })
4366 .count()
4367 });
4368
4369 assert_eq!(retry_count, 1, "Should have one retry message");
4370 }
4371
4372 #[gpui::test]
4373 async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4374 init_test_settings(cx);
4375
4376 let project = create_test_project(cx, json!({})).await;
4377 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4378
4379 // Enable Burn Mode to allow retries
4380 thread.update(cx, |thread, _| {
4381 thread.set_completion_mode(CompletionMode::Burn);
4382 });
4383
4384 // Create model that returns internal server error
4385 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4386
4387 // Insert a user message
4388 thread.update(cx, |thread, cx| {
4389 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4390 });
4391
        // Count completion requests (the initial attempt plus each retry)
4394 let completion_count = Arc::new(Mutex::new(0));
4395 let completion_count_clone = completion_count.clone();
4396
4397 let _subscription = thread.update(cx, |_, cx| {
4398 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4399 if let ThreadEvent::NewRequest = event {
4400 *completion_count_clone.lock() += 1;
4401 }
4402 })
4403 });
4404
4405 // First attempt
4406 thread.update(cx, |thread, cx| {
4407 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4408 });
4409 cx.run_until_parked();
4410
4411 // Should have scheduled first retry - count retry messages
4412 let retry_count = thread.update(cx, |thread, _| {
4413 thread
4414 .messages
4415 .iter()
4416 .filter(|m| {
4417 m.ui_only
4418 && m.segments.iter().any(|s| {
4419 if let MessageSegment::Text(text) = s {
4420 text.contains("Retrying") && text.contains("seconds")
4421 } else {
4422 false
4423 }
4424 })
4425 })
4426 .count()
4427 });
4428 assert_eq!(retry_count, 1, "Should have scheduled first retry");
4429
4430 // Check retry state
4431 thread.read_with(cx, |thread, _| {
4432 assert!(thread.retry_state.is_some(), "Should have retry state");
4433 let retry_state = thread.retry_state.as_ref().unwrap();
4434 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4435 assert_eq!(
4436 retry_state.max_attempts, 3,
4437 "Internal server errors should retry up to 3 times"
4438 );
4439 });
4440
4441 // Advance clock for first retry
4442 cx.executor().advance_clock(BASE_RETRY_DELAY);
4443 cx.run_until_parked();
4444
4445 // Advance clock for second retry
4446 cx.executor().advance_clock(BASE_RETRY_DELAY);
4447 cx.run_until_parked();
4448
4449 // Advance clock for third retry
4450 cx.executor().advance_clock(BASE_RETRY_DELAY);
4451 cx.run_until_parked();
4452
4453 // Should have completed all retries - count retry messages
4454 let retry_count = thread.update(cx, |thread, _| {
4455 thread
4456 .messages
4457 .iter()
4458 .filter(|m| {
4459 m.ui_only
4460 && m.segments.iter().any(|s| {
4461 if let MessageSegment::Text(text) = s {
4462 text.contains("Retrying") && text.contains("seconds")
4463 } else {
4464 false
4465 }
4466 })
4467 })
4468 .count()
4469 });
4470 assert_eq!(
4471 retry_count, 3,
4472 "Should have 3 retries for internal server errors"
4473 );
4474
4475 // For internal server errors, we retry 3 times and then give up
4476 // Check that retry_state is cleared after all retries
4477 thread.read_with(cx, |thread, _| {
4478 assert!(
4479 thread.retry_state.is_none(),
4480 "Retry state should be cleared after all retries"
4481 );
4482 });
4483
4484 // Verify total attempts (1 initial + 3 retries)
4485 assert_eq!(
4486 *completion_count.lock(),
4487 4,
4488 "Should have attempted once plus 3 retries"
4489 );
4490 }
4491
4492 #[gpui::test]
4493 async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4494 init_test_settings(cx);
4495
4496 let project = create_test_project(cx, json!({})).await;
4497 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4498
4499 // Enable Burn Mode to allow retries
4500 thread.update(cx, |thread, _| {
4501 thread.set_completion_mode(CompletionMode::Burn);
4502 });
4503
4504 // Create model that returns overloaded error
4505 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4506
4507 // Insert a user message
4508 thread.update(cx, |thread, cx| {
4509 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4510 });
4511
4512 // Track events
4513 let stopped_with_error = Arc::new(Mutex::new(false));
4514 let stopped_with_error_clone = stopped_with_error.clone();
4515
4516 let _subscription = thread.update(cx, |_, cx| {
4517 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4518 if let ThreadEvent::Stopped(Err(_)) = event {
4519 *stopped_with_error_clone.lock() = true;
4520 }
4521 })
4522 });
4523
4524 // Start initial completion
4525 thread.update(cx, |thread, cx| {
4526 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4527 });
4528 cx.run_until_parked();
4529
4530 // Advance through all retries
4531 for _ in 0..MAX_RETRY_ATTEMPTS {
4532 cx.executor().advance_clock(BASE_RETRY_DELAY);
4533 cx.run_until_parked();
4534 }
4535
4536 let retry_count = thread.update(cx, |thread, _| {
4537 thread
4538 .messages
4539 .iter()
4540 .filter(|m| {
4541 m.ui_only
4542 && m.segments.iter().any(|s| {
4543 if let MessageSegment::Text(text) = s {
4544 text.contains("Retrying") && text.contains("seconds")
4545 } else {
4546 false
4547 }
4548 })
4549 })
4550 .count()
4551 });
4552
4553 // After max retries, should emit Stopped(Err(...)) event
4554 assert_eq!(
4555 retry_count, MAX_RETRY_ATTEMPTS as usize,
4556 "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
4557 );
4558 assert!(
4559 *stopped_with_error.lock(),
4560 "Should emit Stopped(Err(...)) event after max retries exceeded"
4561 );
4562
4563 // Retry state should be cleared
4564 thread.read_with(cx, |thread, _| {
4565 assert!(
4566 thread.retry_state.is_none(),
4567 "Retry state should be cleared after max retries"
4568 );
4569
4570 // Verify we have the expected number of retry messages
4571 let retry_messages = thread
4572 .messages
4573 .iter()
4574 .filter(|msg| msg.ui_only && msg.role == Role::System)
4575 .count();
4576 assert_eq!(
4577 retry_messages, MAX_RETRY_ATTEMPTS as usize,
4578 "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
4579 );
4580 });
4581 }
4582
4583 #[gpui::test]
4584 async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
4585 init_test_settings(cx);
4586
4587 let project = create_test_project(cx, json!({})).await;
4588 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4589
4590 // Enable Burn Mode to allow retries
4591 thread.update(cx, |thread, _| {
4592 thread.set_completion_mode(CompletionMode::Burn);
4593 });
4594
        // Wrapper model that fails the first request with an overloaded error,
        // then delegates to the inner fake so the retry can succeed
4596 struct RetryTestModel {
4597 inner: Arc<FakeLanguageModel>,
4598 failed_once: Arc<Mutex<bool>>,
4599 }
4600
4601 impl LanguageModel for RetryTestModel {
4602 fn id(&self) -> LanguageModelId {
4603 self.inner.id()
4604 }
4605
4606 fn name(&self) -> LanguageModelName {
4607 self.inner.name()
4608 }
4609
4610 fn provider_id(&self) -> LanguageModelProviderId {
4611 self.inner.provider_id()
4612 }
4613
4614 fn provider_name(&self) -> LanguageModelProviderName {
4615 self.inner.provider_name()
4616 }
4617
4618 fn supports_tools(&self) -> bool {
4619 self.inner.supports_tools()
4620 }
4621
4622 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4623 self.inner.supports_tool_choice(choice)
4624 }
4625
4626 fn supports_images(&self) -> bool {
4627 self.inner.supports_images()
4628 }
4629
4630 fn telemetry_id(&self) -> String {
4631 self.inner.telemetry_id()
4632 }
4633
4634 fn max_token_count(&self) -> u64 {
4635 self.inner.max_token_count()
4636 }
4637
4638 fn count_tokens(
4639 &self,
4640 request: LanguageModelRequest,
4641 cx: &App,
4642 ) -> BoxFuture<'static, Result<u64>> {
4643 self.inner.count_tokens(request, cx)
4644 }
4645
4646 fn stream_completion(
4647 &self,
4648 request: LanguageModelRequest,
4649 cx: &AsyncApp,
4650 ) -> BoxFuture<
4651 'static,
4652 Result<
4653 BoxStream<
4654 'static,
4655 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4656 >,
4657 LanguageModelCompletionError,
4658 >,
4659 > {
4660 if !*self.failed_once.lock() {
4661 *self.failed_once.lock() = true;
4662 let provider = self.provider_name();
4663 // Return error on first attempt
4664 let stream = futures::stream::once(async move {
4665 Err(LanguageModelCompletionError::ServerOverloaded {
4666 provider,
4667 retry_after: None,
4668 })
4669 });
4670 async move { Ok(stream.boxed()) }.boxed()
4671 } else {
4672 // Succeed on retry
4673 self.inner.stream_completion(request, cx)
4674 }
4675 }
4676
4677 fn as_fake(&self) -> &FakeLanguageModel {
4678 &self.inner
4679 }
4680 }
4681
4682 let model = Arc::new(RetryTestModel {
4683 inner: Arc::new(FakeLanguageModel::default()),
4684 failed_once: Arc::new(Mutex::new(false)),
4685 });
4686
4687 // Insert a user message
4688 thread.update(cx, |thread, cx| {
4689 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4690 });
4691
        // Track when the retried completion streams successfully
4694 let retry_completed = Arc::new(Mutex::new(false));
4695 let retry_completed_clone = retry_completed.clone();
4696
4697 let _subscription = thread.update(cx, |_, cx| {
4698 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4699 if let ThreadEvent::StreamedCompletion = event {
4700 *retry_completed_clone.lock() = true;
4701 }
4702 })
4703 });
4704
4705 // Start completion
4706 thread.update(cx, |thread, cx| {
4707 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4708 });
4709 cx.run_until_parked();
4710
4711 // Get the retry message ID
4712 let retry_message_id = thread.read_with(cx, |thread, _| {
4713 thread
4714 .messages()
4715 .find(|msg| msg.role == Role::System && msg.ui_only)
4716 .map(|msg| msg.id)
4717 .expect("Should have a retry message")
4718 });
4719
4720 // Wait for retry
4721 cx.executor().advance_clock(BASE_RETRY_DELAY);
4722 cx.run_until_parked();
4723
4724 // Stream some successful content
4725 let fake_model = model.as_fake();
4726 // After the retry, there should be a new pending completion
4727 let pending = fake_model.pending_completions();
4728 assert!(
4729 !pending.is_empty(),
4730 "Should have a pending completion after retry"
4731 );
4732 fake_model.send_completion_stream_text_chunk(&pending[0], "Success!");
4733 fake_model.end_completion_stream(&pending[0]);
4734 cx.run_until_parked();
4735
4736 // Check that the retry completed successfully
4737 assert!(
4738 *retry_completed.lock(),
4739 "Retry should have completed successfully"
4740 );
4741
4742 // Retry message should still exist but be marked as ui_only
4743 thread.read_with(cx, |thread, _| {
4744 let retry_msg = thread
4745 .message(retry_message_id)
4746 .expect("Retry message should still exist");
4747 assert!(retry_msg.ui_only, "Retry message should be ui_only");
4748 assert_eq!(
4749 retry_msg.role,
4750 Role::System,
4751 "Retry message should have System role"
4752 );
4753 });
4754 }
4755
4756 #[gpui::test]
4757 async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4758 init_test_settings(cx);
4759
4760 let project = create_test_project(cx, json!({})).await;
4761 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4762
4763 // Enable Burn Mode to allow retries
4764 thread.update(cx, |thread, _| {
4765 thread.set_completion_mode(CompletionMode::Burn);
4766 });
4767
4768 // Create a model that fails once then succeeds
4769 struct FailOnceModel {
4770 inner: Arc<FakeLanguageModel>,
4771 failed_once: Arc<Mutex<bool>>,
4772 }
4773
4774 impl LanguageModel for FailOnceModel {
4775 fn id(&self) -> LanguageModelId {
4776 self.inner.id()
4777 }
4778
4779 fn name(&self) -> LanguageModelName {
4780 self.inner.name()
4781 }
4782
4783 fn provider_id(&self) -> LanguageModelProviderId {
4784 self.inner.provider_id()
4785 }
4786
4787 fn provider_name(&self) -> LanguageModelProviderName {
4788 self.inner.provider_name()
4789 }
4790
4791 fn supports_tools(&self) -> bool {
4792 self.inner.supports_tools()
4793 }
4794
4795 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4796 self.inner.supports_tool_choice(choice)
4797 }
4798
4799 fn supports_images(&self) -> bool {
4800 self.inner.supports_images()
4801 }
4802
4803 fn telemetry_id(&self) -> String {
4804 self.inner.telemetry_id()
4805 }
4806
4807 fn max_token_count(&self) -> u64 {
4808 self.inner.max_token_count()
4809 }
4810
4811 fn count_tokens(
4812 &self,
4813 request: LanguageModelRequest,
4814 cx: &App,
4815 ) -> BoxFuture<'static, Result<u64>> {
4816 self.inner.count_tokens(request, cx)
4817 }
4818
4819 fn stream_completion(
4820 &self,
4821 request: LanguageModelRequest,
4822 cx: &AsyncApp,
4823 ) -> BoxFuture<
4824 'static,
4825 Result<
4826 BoxStream<
4827 'static,
4828 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4829 >,
4830 LanguageModelCompletionError,
4831 >,
4832 > {
4833 if !*self.failed_once.lock() {
4834 *self.failed_once.lock() = true;
4835 let provider = self.provider_name();
4836 // Return error on first attempt
4837 let stream = futures::stream::once(async move {
4838 Err(LanguageModelCompletionError::ServerOverloaded {
4839 provider,
4840 retry_after: None,
4841 })
4842 });
4843 async move { Ok(stream.boxed()) }.boxed()
4844 } else {
4845 // Succeed on retry
4846 self.inner.stream_completion(request, cx)
4847 }
4848 }
4849 }
4850
4851 let fail_once_model = Arc::new(FailOnceModel {
4852 inner: Arc::new(FakeLanguageModel::default()),
4853 failed_once: Arc::new(Mutex::new(false)),
4854 });
4855
4856 // Insert a user message
4857 thread.update(cx, |thread, cx| {
4858 thread.insert_user_message(
4859 "Test message",
4860 ContextLoadResult::default(),
4861 None,
4862 vec![],
4863 cx,
4864 );
4865 });
4866
4867 // Start completion with fail-once model
4868 thread.update(cx, |thread, cx| {
4869 thread.send_to_model(
4870 fail_once_model.clone(),
4871 CompletionIntent::UserPrompt,
4872 None,
4873 cx,
4874 );
4875 });
4876
4877 cx.run_until_parked();
4878
4879 // Verify retry state exists after first failure
4880 thread.read_with(cx, |thread, _| {
4881 assert!(
4882 thread.retry_state.is_some(),
4883 "Should have retry state after failure"
4884 );
4885 });
4886
4887 // Wait for retry delay
4888 cx.executor().advance_clock(BASE_RETRY_DELAY);
4889 cx.run_until_parked();
4890
4891 // The retry should now use our FailOnceModel which should succeed
4892 // We need to help the FakeLanguageModel complete the stream
4893 let inner_fake = fail_once_model.inner.clone();
4894
4895 // Wait a bit for the retry to start
4896 cx.run_until_parked();
4897
4898 // Check for pending completions and complete them
4899 if let Some(pending) = inner_fake.pending_completions().first() {
4900 inner_fake.send_completion_stream_text_chunk(pending, "Success!");
4901 inner_fake.end_completion_stream(pending);
4902 }
4903 cx.run_until_parked();
4904
4905 thread.read_with(cx, |thread, _| {
4906 assert!(
4907 thread.retry_state.is_none(),
4908 "Retry state should be cleared after successful completion"
4909 );
4910
4911 let has_assistant_message = thread
4912 .messages
4913 .iter()
4914 .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4915 assert!(
4916 has_assistant_message,
4917 "Should have an assistant message after successful retry"
4918 );
4919 });
4920 }
4921
4922 #[gpui::test]
4923 async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4924 init_test_settings(cx);
4925
4926 let project = create_test_project(cx, json!({})).await;
4927 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4928
4929 // Enable Burn Mode to allow retries
4930 thread.update(cx, |thread, _| {
4931 thread.set_completion_mode(CompletionMode::Burn);
4932 });
4933
        // Create a model that returns a rate-limit error carrying a retry_after hint
4935 struct RateLimitModel {
4936 inner: Arc<FakeLanguageModel>,
4937 }
4938
4939 impl LanguageModel for RateLimitModel {
4940 fn id(&self) -> LanguageModelId {
4941 self.inner.id()
4942 }
4943
4944 fn name(&self) -> LanguageModelName {
4945 self.inner.name()
4946 }
4947
4948 fn provider_id(&self) -> LanguageModelProviderId {
4949 self.inner.provider_id()
4950 }
4951
4952 fn provider_name(&self) -> LanguageModelProviderName {
4953 self.inner.provider_name()
4954 }
4955
4956 fn supports_tools(&self) -> bool {
4957 self.inner.supports_tools()
4958 }
4959
4960 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4961 self.inner.supports_tool_choice(choice)
4962 }
4963
4964 fn supports_images(&self) -> bool {
4965 self.inner.supports_images()
4966 }
4967
4968 fn telemetry_id(&self) -> String {
4969 self.inner.telemetry_id()
4970 }
4971
4972 fn max_token_count(&self) -> u64 {
4973 self.inner.max_token_count()
4974 }
4975
4976 fn count_tokens(
4977 &self,
4978 request: LanguageModelRequest,
4979 cx: &App,
4980 ) -> BoxFuture<'static, Result<u64>> {
4981 self.inner.count_tokens(request, cx)
4982 }
4983
4984 fn stream_completion(
4985 &self,
4986 _request: LanguageModelRequest,
4987 _cx: &AsyncApp,
4988 ) -> BoxFuture<
4989 'static,
4990 Result<
4991 BoxStream<
4992 'static,
4993 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4994 >,
4995 LanguageModelCompletionError,
4996 >,
4997 > {
4998 let provider = self.provider_name();
4999 async move {
5000 let stream = futures::stream::once(async move {
5001 Err(LanguageModelCompletionError::RateLimitExceeded {
5002 provider,
5003 retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
5004 })
5005 });
5006 Ok(stream.boxed())
5007 }
5008 .boxed()
5009 }
5010
5011 fn as_fake(&self) -> &FakeLanguageModel {
5012 &self.inner
5013 }
5014 }
5015
5016 let model = Arc::new(RateLimitModel {
5017 inner: Arc::new(FakeLanguageModel::default()),
5018 });
5019
5020 // Insert a user message
5021 thread.update(cx, |thread, cx| {
5022 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5023 });
5024
5025 // Start completion
5026 thread.update(cx, |thread, cx| {
5027 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5028 });
5029
5030 cx.run_until_parked();
5031
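        // A retry should have been scheduled; count the retry notification messages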
5032 let retry_count = thread.update(cx, |thread, _| {
5033 thread
5034 .messages
5035 .iter()
5036 .filter(|m| {
5037 m.ui_only
5038 && m.segments.iter().any(|s| {
5039 if let MessageSegment::Text(text) = s {
5040 text.contains("rate limit exceeded")
5041 } else {
5042 false
5043 }
5044 })
5045 })
5046 .count()
5047 });
5048 assert_eq!(retry_count, 1, "Should have scheduled one retry");
5049
5050 thread.read_with(cx, |thread, _| {
5051 assert!(
5052 thread.retry_state.is_some(),
5053 "Rate limit errors should set retry_state"
5054 );
5055 if let Some(retry_state) = &thread.retry_state {
5056 assert_eq!(
5057 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
5058 "Rate limit errors should use MAX_RETRY_ATTEMPTS"
5059 );
5060 }
5061 });
5062
5063 // Verify we have one retry message
5064 thread.read_with(cx, |thread, _| {
5065 let retry_messages = thread
5066 .messages
5067 .iter()
5068 .filter(|msg| {
5069 msg.ui_only
5070 && msg.segments.iter().any(|seg| {
5071 if let MessageSegment::Text(text) = seg {
5072 text.contains("rate limit exceeded")
5073 } else {
5074 false
5075 }
5076 })
5077 })
5078 .count();
5079 assert_eq!(
5080 retry_messages, 1,
5081 "Should have one rate limit retry message"
5082 );
5083 });
5084
        // Check that the retry message includes the attempt count
5086 thread.read_with(cx, |thread, _| {
5087 let retry_message = thread
5088 .messages
5089 .iter()
5090 .find(|msg| msg.role == Role::System && msg.ui_only)
5091 .expect("Should have a retry message");
5092
            // The message should contain the attempt count, since rate limits go through retry_state
5094 if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5095 assert!(
5096 text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)),
5097 "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
5098 );
5099 assert!(
5100 text.contains("Retrying"),
5101 "Rate limit retry message should contain retry text"
5102 );
5103 }
5104 });
5105 }
5106
5107 #[gpui::test]
5108 async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5109 init_test_settings(cx);
5110
5111 let project = create_test_project(cx, json!({})).await;
5112 let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5113
5114 // Insert a regular user message
5115 thread.update(cx, |thread, cx| {
5116 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5117 });
5118
5119 // Insert a UI-only message (like our retry notifications)
5120 thread.update(cx, |thread, cx| {
5121 let id = thread.next_message_id.post_inc();
5122 thread.messages.push(Message {
5123 id,
5124 role: Role::System,
5125 segments: vec![MessageSegment::Text(
5126 "This is a UI-only message that should not be sent to the model".to_string(),
5127 )],
5128 loaded_context: LoadedContext::default(),
5129 creases: Vec::new(),
5130 is_hidden: true,
5131 ui_only: true,
5132 });
5133 cx.emit(ThreadEvent::MessageAdded(id));
5134 });
5135
5136 // Insert another regular message
5137 thread.update(cx, |thread, cx| {
5138 thread.insert_user_message(
5139 "How are you?",
5140 ContextLoadResult::default(),
5141 None,
5142 vec![],
5143 cx,
5144 );
5145 });
5146
5147 // Generate the completion request
5148 let request = thread.update(cx, |thread, cx| {
5149 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5150 });
5151
5152 // Verify that the request only contains non-UI-only messages
5153 // Should have system prompt + 2 user messages, but not the UI-only message
5154 let user_messages: Vec<_> = request
5155 .messages
5156 .iter()
5157 .filter(|msg| msg.role == Role::User)
5158 .collect();
5159 assert_eq!(
5160 user_messages.len(),
5161 2,
5162 "Should have exactly 2 user messages"
5163 );
5164
5165 // Verify the UI-only content is not present anywhere in the request
5166 let request_text = request
5167 .messages
5168 .iter()
5169 .flat_map(|msg| &msg.content)
5170 .filter_map(|content| match content {
5171 MessageContent::Text(text) => Some(text.as_str()),
5172 _ => None,
5173 })
5174 .collect::<String>();
5175
5176 assert!(
5177 !request_text.contains("UI-only message"),
5178 "UI-only message content should not be in the request"
5179 );
5180
5181 // Verify the thread still has all 3 messages (including UI-only)
5182 thread.read_with(cx, |thread, _| {
5183 assert_eq!(
5184 thread.messages().count(),
5185 3,
5186 "Thread should have 3 messages"
5187 );
5188 assert_eq!(
5189 thread.messages().filter(|m| m.ui_only).count(),
5190 1,
5191 "Thread should have 1 UI-only message"
5192 );
5193 });
5194
5195 // Verify that UI-only messages are not serialized
5196 let serialized = thread
5197 .update(cx, |thread, cx| thread.serialize(cx))
5198 .await
5199 .unwrap();
5200 assert_eq!(
5201 serialized.messages.len(),
5202 2,
5203 "Serialized thread should only have 2 messages (no UI-only)"
5204 );
5205 }
5206
5207 #[gpui::test]
5208 async fn test_no_retry_without_burn_mode(cx: &mut TestAppContext) {
5209 init_test_settings(cx);
5210
5211 let project = create_test_project(cx, json!({})).await;
5212 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5213
5214 // Ensure we're in Normal mode (not Burn mode)
5215 thread.update(cx, |thread, _| {
5216 thread.set_completion_mode(CompletionMode::Normal);
5217 });
5218
5219 // Track error events
5220 let error_events = Arc::new(Mutex::new(Vec::new()));
5221 let error_events_clone = error_events.clone();
5222
5223 let _subscription = thread.update(cx, |_, cx| {
5224 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
5225 if let ThreadEvent::ShowError(error) = event {
5226 error_events_clone.lock().push(error.clone());
5227 }
5228 })
5229 });
5230
5231 // Create model that returns overloaded error
5232 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5233
5234 // Insert a user message
5235 thread.update(cx, |thread, cx| {
5236 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5237 });
5238
5239 // Start completion
5240 thread.update(cx, |thread, cx| {
5241 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5242 });
5243
5244 cx.run_until_parked();
5245
5246 // Verify no retry state was created
5247 thread.read_with(cx, |thread, _| {
5248 assert!(
5249 thread.retry_state.is_none(),
5250 "Should not have retry state in Normal mode"
5251 );
5252 });
5253
5254 // Check that a retryable error was reported
5255 let errors = error_events.lock();
5256 assert!(!errors.is_empty(), "Should have received an error event");
5257
5258 if let ThreadError::RetryableError {
5259 message: _,
5260 can_enable_burn_mode,
5261 } = &errors[0]
5262 {
5263 assert!(
5264 *can_enable_burn_mode,
5265 "Error should indicate burn mode can be enabled"
5266 );
5267 } else {
5268 panic!("Expected RetryableError, got {:?}", errors[0]);
5269 }
5270
5271 // Verify the thread is no longer generating
5272 thread.read_with(cx, |thread, _| {
5273 assert!(
5274 !thread.is_generating(),
5275 "Should not be generating after error without retry"
5276 );
5277 });
5278 }
5279
5280 #[gpui::test]
5281 async fn test_retry_canceled_on_stop(cx: &mut TestAppContext) {
5282 init_test_settings(cx);
5283
5284 let project = create_test_project(cx, json!({})).await;
5285 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5286
5287 // Enable Burn Mode to allow retries
5288 thread.update(cx, |thread, _| {
5289 thread.set_completion_mode(CompletionMode::Burn);
5290 });
5291
5292 // Create model that returns overloaded error
5293 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5294
5295 // Insert a user message
5296 thread.update(cx, |thread, cx| {
5297 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5298 });
5299
5300 // Start completion
5301 thread.update(cx, |thread, cx| {
5302 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5303 });
5304
5305 cx.run_until_parked();
5306
5307 // Verify retry was scheduled by checking for retry message
5308 let has_retry_message = thread.read_with(cx, |thread, _| {
5309 thread.messages.iter().any(|m| {
5310 m.ui_only
5311 && m.segments.iter().any(|s| {
5312 if let MessageSegment::Text(text) = s {
5313 text.contains("Retrying") && text.contains("seconds")
5314 } else {
5315 false
5316 }
5317 })
5318 })
5319 });
5320 assert!(has_retry_message, "Should have scheduled a retry");
5321
5322 // Cancel the completion before the retry happens
5323 thread.update(cx, |thread, cx| {
5324 thread.cancel_last_completion(None, cx);
5325 });
5326
5327 cx.run_until_parked();
5328
5329 // The retry should not have happened - no pending completions
5330 let fake_model = model.as_fake();
5331 assert_eq!(
5332 fake_model.pending_completions().len(),
5333 0,
5334 "Should have no pending completions after cancellation"
5335 );
5336
5337 // Verify the retry was canceled by checking retry state
5338 thread.read_with(cx, |thread, _| {
5339 if let Some(retry_state) = &thread.retry_state {
5340 panic!(
5341 "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5342 retry_state.attempt, retry_state.max_attempts, retry_state.intent
5343 );
5344 }
5345 });
5346 }
5347
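    // Drives a summarization request that ends without streaming any text,
    // leaving the thread's summary in the Error state.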
5348 fn test_summarize_error(
5349 model: &Arc<dyn LanguageModel>,
5350 thread: &Entity<Thread>,
5351 cx: &mut TestAppContext,
5352 ) {
5353 thread.update(cx, |thread, cx| {
5354 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5355 thread.send_to_model(
5356 model.clone(),
5357 CompletionIntent::ThreadSummarization,
5358 None,
5359 cx,
5360 );
5361 });
5362
5363 let fake_model = model.as_fake();
5364 simulate_successful_response(fake_model, cx);
5365
5366 thread.read_with(cx, |thread, _| {
5367 assert!(matches!(thread.summary(), ThreadSummary::Generating));
5368 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5369 });
5370
5371 // Simulate summary request ending
5372 cx.run_until_parked();
5373 fake_model.end_last_completion_stream();
5374 cx.run_until_parked();
5375
5376 // State is set to Error and default message
5377 thread.read_with(cx, |thread, _| {
5378 assert!(matches!(thread.summary(), ThreadSummary::Error));
5379 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5380 });
5381 }
5382
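    // Streams a canned assistant response and closes the completion stream.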
5383 fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
5384 cx.run_until_parked();
5385 fake_model.send_last_completion_stream_text_chunk("Assistant response");
5386 fake_model.end_last_completion_stream();
5387 cx.run_until_parked();
5388 }
5389
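    // Registers the settings, tool registry, and fake HTTP client that the tests rely on.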
5390 fn init_test_settings(cx: &mut TestAppContext) {
5391 cx.update(|cx| {
5392 let settings_store = SettingsStore::test(cx);
5393 cx.set_global(settings_store);
5394 language::init(cx);
5395 Project::init_settings(cx);
5396 AgentSettings::register(cx);
5397 prompt_store::init(cx);
5398 thread_store::init(cx);
5399 workspace::init_settings(cx);
5400 language_model::init_settings(cx);
5401 ThemeSettings::register(cx);
5402 ToolRegistry::default_global(cx);
5403 assistant_tool::init(cx);
5404
5405 let http_client = Arc::new(http_client::HttpClientWithUrl::new(
5406 http_client::FakeHttpClient::with_200_response(),
5407 "http://localhost".to_string(),
5408 None,
5409 ));
5410 assistant_tools::init(http_client, cx);
5411 });
5412 }
5413
5414 // Helper to create a test project with test files
5415 async fn create_test_project(
5416 cx: &mut TestAppContext,
5417 files: serde_json::Value,
5418 ) -> Entity<Project> {
5419 let fs = FakeFs::new(cx.executor());
5420 fs.insert_tree(path!("/test"), files).await;
5421 Project::test(fs, [path!("/test").as_ref()], cx).await
5422 }
5423
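    // Builds the workspace, thread store, thread, and context store used by
    // most tests, backed by a FakeLanguageModel.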
5424 async fn setup_test_environment(
5425 cx: &mut TestAppContext,
5426 project: Entity<Project>,
5427 ) -> (
5428 Entity<Workspace>,
5429 Entity<ThreadStore>,
5430 Entity<Thread>,
5431 Entity<ContextStore>,
5432 Arc<dyn LanguageModel>,
5433 ) {
5434 let (workspace, cx) =
5435 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5436
5437 let thread_store = cx
5438 .update(|_, cx| {
5439 ThreadStore::load(
5440 project.clone(),
5441 cx.new(|_| ToolWorkingSet::default()),
5442 None,
5443 Arc::new(PromptBuilder::new(None).unwrap()),
5444 cx,
5445 )
5446 })
5447 .await
5448 .unwrap();
5449
5450 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5451 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5452
5453 let provider = Arc::new(FakeLanguageModelProvider::default());
5454 let model = provider.test_model();
5455 let model: Arc<dyn LanguageModel> = Arc::new(model);
5456
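        // Register the fake model as both the default and the thread-summary model.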
5457 cx.update(|_, cx| {
5458 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5459 registry.set_default_model(
5460 Some(ConfiguredModel {
5461 provider: provider.clone(),
5462 model: model.clone(),
5463 }),
5464 cx,
5465 );
5466 registry.set_thread_summary_model(
5467 Some(ConfiguredModel {
5468 provider,
5469 model: model.clone(),
5470 }),
5471 cx,
5472 );
5473 })
5474 });
5475
5476 (workspace, thread_store, thread, context_store, model)
5477 }
5478
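    // Opens the file at `path` in the project and adds its buffer to the context store.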
5479 async fn add_file_to_context(
5480 project: &Entity<Project>,
5481 context_store: &Entity<ContextStore>,
5482 path: &str,
5483 cx: &mut TestAppContext,
5484 ) -> Result<Entity<language::Buffer>> {
5485 let buffer_path = project
5486 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5487 .unwrap();
5488
5489 let buffer = project
5490 .update(cx, |project, cx| {
5491 project.open_buffer(buffer_path.clone(), cx)
5492 })
5493 .await
5494 .unwrap();
5495
5496 context_store.update(cx, |context_store, cx| {
5497 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5498 });
5499
5500 Ok(buffer)
5501 }
5502}