1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9 tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
10};
11use action_log::ActionLog;
12use agent_settings::{
13 AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
14 SUMMARIZE_THREAD_PROMPT,
15};
16use anyhow::{Result, anyhow};
17use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet};
18use chrono::{DateTime, Utc};
19use client::{ModelRequestUsage, RequestUsage};
20use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
21use collections::HashMap;
22use futures::{FutureExt, StreamExt as _, future::Shared};
23use git::repository::DiffType;
24use gpui::{
25 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
26 WeakEntity, Window,
27};
28use http_client::StatusCode;
29use language_model::{
30 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
31 LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
32 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
33 LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
34 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
35 TokenUsage,
36};
37use postage::stream::Stream as _;
38use project::{
39 Project,
40 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
41};
42use prompt_store::{ModelContext, PromptBuilder};
43use schemars::JsonSchema;
44use serde::{Deserialize, Serialize};
45use settings::Settings;
46use std::{
47 io::Write,
48 ops::Range,
49 sync::Arc,
50 time::{Duration, Instant},
51};
52use thiserror::Error;
53use util::{ResultExt as _, post_inc};
54use uuid::Uuid;
55
/// Maximum number of completion retry attempts before the thread gives up.
const MAX_RETRY_ATTEMPTS: u8 = 4;
/// Base delay between retries; retry strategies may scale this (e.g. exponentially).
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
58
/// How a failed completion request should be retried.
#[derive(Debug, Clone)]
enum RetryStrategy {
    /// Retry with a delay that grows after each attempt, starting at
    /// `initial_delay`, up to `max_attempts` tries.
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    /// Retry with the same `delay` every time, up to `max_attempts` tries.
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}
70
/// Globally unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
75
76impl ThreadId {
77 pub fn new() -> Self {
78 Self(Uuid::new_v4().to_string().into())
79 }
80}
81
82impl std::fmt::Display for ThreadId {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 write!(f, "{}", self.0)
85 }
86}
87
88impl From<&str> for ThreadId {
89 fn from(value: &str) -> Self {
90 Self(value.into())
91 }
92}
93
94/// The ID of the user prompt that initiated a request.
95///
96/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>); // UUID string, regenerated per user submission
99
100impl PromptId {
101 pub fn new() -> Self {
102 Self(Uuid::new_v4().to_string().into())
103 }
104}
105
106impl std::fmt::Display for PromptId {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 write!(f, "{}", self.0)
109 }
110}
111
/// Monotonically increasing, thread-local index of a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub usize);
114
impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    /// The raw index value.
    pub fn as_usize(&self) -> usize {
        self.0
    }
}
124
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    /// Icon shown on the collapsed crease.
    pub icon_path: SharedString,
    /// Label shown on the collapsed crease.
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
134
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context (files, etc.) that was attached when the message was sent.
    pub loaded_context: LoadedContext,
    /// Creases to restore when re-opening this message in an editor.
    pub creases: Vec<MessageCrease>,
    /// Hidden messages are sent to the model but not rendered in the UI.
    pub is_hidden: bool,
    /// UI-only messages are rendered but never sent to the model or persisted.
    pub ui_only: bool,
}
146
147impl Message {
148 /// Returns whether the message contains any meaningful text that should be displayed
149 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
150 pub fn should_display_content(&self) -> bool {
151 self.segments.iter().all(|segment| segment.should_display())
152 }
153
154 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
155 if let Some(MessageSegment::Thinking {
156 text: segment,
157 signature: current_signature,
158 }) = self.segments.last_mut()
159 {
160 if let Some(signature) = signature {
161 *current_signature = Some(signature);
162 }
163 segment.push_str(text);
164 } else {
165 self.segments.push(MessageSegment::Thinking {
166 text: text.to_string(),
167 signature,
168 });
169 }
170 }
171
172 pub fn push_redacted_thinking(&mut self, data: String) {
173 self.segments.push(MessageSegment::RedactedThinking(data));
174 }
175
176 pub fn push_text(&mut self, text: &str) {
177 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
178 segment.push_str(text);
179 } else {
180 self.segments.push(MessageSegment::Text(text.to_string()));
181 }
182 }
183
184 pub fn to_message_content(&self) -> String {
185 let mut result = String::new();
186
187 if !self.loaded_context.text.is_empty() {
188 result.push_str(&self.loaded_context.text);
189 }
190
191 for segment in &self.segments {
192 match segment {
193 MessageSegment::Text(text) => result.push_str(text),
194 MessageSegment::Thinking { text, .. } => {
195 result.push_str("<think>\n");
196 result.push_str(text);
197 result.push_str("\n</think>");
198 }
199 MessageSegment::RedactedThinking(_) => {}
200 }
201 }
202
203 result
204 }
205}
206
/// One piece of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Extended "thinking" text, optionally signed by the provider.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking; opaque data with no displayable text.
    RedactedThinking(String),
}

impl MessageSegment {
    /// Whether this segment carries visible text worth rendering.
    ///
    /// Fix: this previously returned `text.is_empty()`, i.e. it claimed
    /// *empty* segments should be displayed and non-empty ones should not —
    /// the inverse of the contract documented on
    /// `Message::should_display_content`.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking is opaque data; nothing to display.
            Self::RedactedThinking(_) => false,
        }
    }

    /// The plain-text content, if this is a `Text` segment.
    pub fn text(&self) -> Option<&str> {
        match self {
            MessageSegment::Text(text) => Some(text),
            _ => None,
        }
    }
}
233
/// Point-in-time capture of the project's worktrees, used for feedback/telemetry.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub timestamp: DateTime<Utc>,
}
239
/// Snapshot of a single worktree: its path and (if any) git state.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}
245
/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff at snapshot time, if one could be computed.
    pub diff: Option<String>,
}
253
/// Associates a git checkpoint with the message it was taken before,
/// so the project can be restored to that point in the conversation.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
259
/// Thumbs-up / thumbs-down feedback a user can attach to a message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
265
/// Status of the most recent checkpoint-restore operation.
pub enum LastRestoreCheckpoint {
    /// Restore is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// Restore failed; `error` holds the failure description.
    Error {
        message_id: MessageId,
        error: String,
    },
}
275
276impl LastRestoreCheckpoint {
277 pub fn message_id(&self) -> MessageId {
278 match self {
279 LastRestoreCheckpoint::Pending { message_id } => *message_id,
280 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
281 }
282 }
283}
284
/// Lifecycle of the long-form thread summary (distinct from the short title).
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// Generation is in progress, covering messages up to `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// A summary exists, covering messages up to `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
297
298impl DetailedSummaryState {
299 fn text(&self) -> Option<SharedString> {
300 if let Self::Generated { text, .. } = self {
301 Some(text.clone())
302 } else {
303 None
304 }
305 }
306}
307
/// Running token usage measured against a model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Context-window size; 0 when no model is selected (unknown).
    pub max: u64,
}
313
314impl TotalTokenUsage {
315 pub fn ratio(&self) -> TokenUsageRatio {
316 #[cfg(debug_assertions)]
317 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
318 .unwrap_or("0.8".to_string())
319 .parse()
320 .unwrap();
321 #[cfg(not(debug_assertions))]
322 let warning_threshold: f32 = 0.8;
323
324 // When the maximum is unknown because there is no selected model,
325 // avoid showing the token limit warning.
326 if self.max == 0 {
327 TokenUsageRatio::Normal
328 } else if self.total >= self.max {
329 TokenUsageRatio::Exceeded
330 } else if self.total as f32 / self.max as f32 >= warning_threshold {
331 TokenUsageRatio::Warning
332 } else {
333 TokenUsageRatio::Normal
334 }
335 }
336
337 pub fn add(&self, tokens: u64) -> TotalTokenUsage {
338 TotalTokenUsage {
339 total: self.total + tokens,
340 max: self.max,
341 }
342 }
343}
344
/// Severity bucket for context-window consumption; see [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    /// Above the warning threshold but still under the limit.
    Warning,
    /// At or over the context-window limit.
    Exceeded,
}
352
/// Where a pending completion sits in the provider's request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request submitted, no queue information yet.
    Sending,
    /// Waiting in the provider queue at the given position.
    Queued { position: usize },
    /// The provider has started processing the request.
    Started,
}
359
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short title shown in thread lists.
    summary: ThreadSummary,
    /// In-flight task generating the short summary.
    pending_summary: Task<Option<()>>,
    /// In-flight task generating the detailed summary.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// Conversation messages, ordered by ascending `MessageId`.
    messages: Vec<Message>,
    next_message_id: MessageId,
    /// ID of the most recent user-initiated prompt.
    last_prompt_id: PromptId,
    /// Shared system-prompt context for the project.
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message they were taken before.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    /// Tracks pending and completed tool uses across messages.
    tool_use: ToolUseState,
    /// Records buffer reads/edits performed on the model's behalf.
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint armed by the latest user message, finalized when generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage, in request order.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    /// Set when the last request overflowed the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// When the most recent streaming chunk arrived; used for staleness detection.
    last_received_chunk_at: Option<Instant>,
    /// Test/debug hook observing each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining agentic turns before the thread stops auto-continuing.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
    /// Model and intent of the last failed request, kept for retry.
    last_error_context: Option<(Arc<dyn LanguageModel>, CompletionIntent)>,
}
400
/// Progress of an in-flight retry sequence for a failed completion.
#[derive(Clone, Debug)]
struct RetryState {
    /// 1-based attempt counter.
    attempt: u8,
    max_attempts: u8,
    /// Intent to resend on the next attempt.
    intent: CompletionIntent,
}
407
/// Lifecycle of the short thread title.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// No summary requested yet.
    Pending,
    /// Summary generation is in flight.
    Generating,
    /// Summary is available.
    Ready(SharedString),
    /// Summary generation failed.
    Error,
}
415
416impl ThreadSummary {
417 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
418
419 pub fn or_default(&self) -> SharedString {
420 self.unwrap_or(Self::DEFAULT)
421 }
422
423 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
424 self.ready().unwrap_or_else(|| message.into())
425 }
426
427 pub fn ready(&self) -> Option<SharedString> {
428 match self {
429 ThreadSummary::Ready(summary) => Some(summary.clone()),
430 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
431 }
432 }
433}
434
/// Records that a request overflowed a model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
442
443impl Thread {
444 pub fn new(
445 project: Entity<Project>,
446 tools: Entity<ToolWorkingSet>,
447 prompt_builder: Arc<PromptBuilder>,
448 system_prompt: SharedProjectContext,
449 cx: &mut Context<Self>,
450 ) -> Self {
451 let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
452 let configured_model = LanguageModelRegistry::read_global(cx).default_model();
453 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
454
455 Self {
456 id: ThreadId::new(),
457 updated_at: Utc::now(),
458 summary: ThreadSummary::Pending,
459 pending_summary: Task::ready(None),
460 detailed_summary_task: Task::ready(None),
461 detailed_summary_tx,
462 detailed_summary_rx,
463 completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
464 messages: Vec::new(),
465 next_message_id: MessageId(0),
466 last_prompt_id: PromptId::new(),
467 project_context: system_prompt,
468 checkpoints_by_message: HashMap::default(),
469 completion_count: 0,
470 pending_completions: Vec::new(),
471 project: project.clone(),
472 prompt_builder,
473 tools: tools.clone(),
474 last_restore_checkpoint: None,
475 pending_checkpoint: None,
476 tool_use: ToolUseState::new(tools.clone()),
477 action_log: cx.new(|_| ActionLog::new(project.clone())),
478 initial_project_snapshot: {
479 let project_snapshot = Self::project_snapshot(project, cx);
480 cx.foreground_executor()
481 .spawn(async move { Some(project_snapshot.await) })
482 .shared()
483 },
484 request_token_usage: Vec::new(),
485 cumulative_token_usage: TokenUsage::default(),
486 exceeded_window_error: None,
487 tool_use_limit_reached: false,
488 retry_state: None,
489 message_feedback: HashMap::default(),
490 last_error_context: None,
491 last_received_chunk_at: None,
492 request_callback: None,
493 remaining_turns: u32::MAX,
494 configured_model,
495 profile: AgentProfile::new(profile_id, tools),
496 }
497 }
498
    /// Reconstructs a thread from its persisted form.
    ///
    /// Transient state (pending completions, checkpoints, retry state,
    /// per-message context handles) is not persisted and starts fresh.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        // Rebuild tool-use bookkeeping (and tool cards, when a window exists)
        // from the persisted messages.
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the
        // global default when it is no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        // Older serialized threads may predate these fields; use settings defaults.
        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // live context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            message_feedback: HashMap::default(),
            last_error_context: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
621
    /// Installs a hook that observes every completion request and its
    /// streamed events (used by tests and debugging tooling).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
629
    /// The unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
633
    /// The currently active agent profile.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }
637
638 pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
639 if &id != self.profile.id() {
640 self.profile = AgentProfile::new(id, self.tools.clone());
641 cx.emit(ThreadEvent::ProfileChanged);
642 }
643 }
644
    /// Whether the thread contains no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
648
    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
652
    /// Bumps the last-modified timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
656
    /// Starts a new prompt ID; called when the user submits a new message.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
660
    /// Shared system-prompt context for the project.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
664
665 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
666 if self.configured_model.is_none() {
667 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
668 }
669 self.configured_model.clone()
670 }
671
    /// The model currently selected for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }
675
    /// Selects (or clears) the model used by this thread.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
680
    /// The current state of the thread's short title.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
684
685 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
686 let current_summary = match &self.summary {
687 ThreadSummary::Pending | ThreadSummary::Generating => return,
688 ThreadSummary::Ready(summary) => summary,
689 ThreadSummary::Error => &ThreadSummary::DEFAULT,
690 };
691
692 let mut new_summary = new_summary.into();
693
694 if new_summary.is_empty() {
695 new_summary = ThreadSummary::DEFAULT;
696 }
697
698 if current_summary != &new_summary {
699 self.summary = ThreadSummary::Ready(new_summary);
700 cx.emit(ThreadEvent::SummaryChanged);
701 }
702 }
703
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
707
    /// Changes the completion mode used for future requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
711
712 pub fn message(&self, id: MessageId) -> Option<&Message> {
713 let index = self
714 .messages
715 .binary_search_by(|message| message.id.cmp(&id))
716 .ok()?;
717
718 self.messages.get(index)
719 }
720
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
724
    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
728
729 /// Indicates whether streaming of language model events is stale.
730 /// When `is_generating()` is false, this method returns `None`.
731 pub fn is_generation_stale(&self) -> Option<bool> {
732 const STALE_THRESHOLD: u128 = 250;
733
734 self.last_received_chunk_at
735 .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
736 }
737
    /// Records that a streaming chunk just arrived (for staleness tracking).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
741
    /// Queue position of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
747
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
751
752 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
753 self.tool_use
754 .pending_tool_uses()
755 .into_iter()
756 .find(|tool_use| &tool_use.id == id)
757 }
758
    /// Pending tool uses that are blocked awaiting user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
765
    /// Whether any tool use has not yet completed.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
769
    /// The git checkpoint associated with the given message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
773
    /// Restores the project to `checkpoint` and, on success, truncates the
    /// conversation back to the checkpoint's message.
    ///
    /// Progress is surfaced via `last_restore_checkpoint` and
    /// `CheckpointChanged` events; on failure the error is recorded there
    /// and the conversation is left untouched.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI can show progress.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed state around so it can be displayed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop messages from the checkpointed message onward.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any in-progress checkpoint is invalidated either way.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
809
810 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
811 let pending_checkpoint = if self.is_generating() {
812 return;
813 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
814 checkpoint
815 } else {
816 return;
817 };
818
819 self.finalize_checkpoint(pending_checkpoint, cx);
820 }
821
    /// Compares `pending_checkpoint` against the project's current state:
    /// if nothing changed the checkpoint is kept pending (reused for the
    /// next turn); otherwise it is committed and a fresh checkpoint of the
    /// current state becomes pending.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Treat a comparison failure as "changed" so we err on the
                // side of keeping a checkpoint.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                this.update(cx, |this, cx| {
                    this.pending_checkpoint = if equal {
                        // No changes since the checkpoint: keep it armed.
                        Some(pending_checkpoint)
                    } else {
                        // Changes happened: commit the old checkpoint and arm
                        // a new one at the current state.
                        this.insert_checkpoint(pending_checkpoint, cx);
                        Some(ThreadCheckpoint {
                            message_id: this.next_message_id,
                            git_checkpoint: final_checkpoint,
                        })
                    }
                })?;

                Ok(())
            }
            // Couldn't snapshot the current state: commit what we have.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
862
    /// Commits a checkpoint, making it visible on its associated message.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
869
    /// Status of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
873
874 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
875 let Some(message_ix) = self
876 .messages
877 .iter()
878 .rposition(|message| message.id == message_id)
879 else {
880 return;
881 };
882 for deleted_message in self.messages.drain(message_ix..) {
883 self.checkpoints_by_message.remove(&deleted_message.id);
884 }
885 cx.notify();
886 }
887
    /// Iterates over the context items attached to the given message
    /// (empty if the message does not exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
895
896 pub fn is_turn_end(&self, ix: usize) -> bool {
897 if self.messages.is_empty() {
898 return false;
899 }
900
901 if !self.is_generating() && ix == self.messages.len() - 1 {
902 return true;
903 }
904
905 let Some(message) = self.messages.get(ix) else {
906 return false;
907 };
908
909 if message.role != Role::Assistant {
910 return false;
911 }
912
913 self.messages
914 .get(ix + 1)
915 .and_then(|message| {
916 self.message(message.id)
917 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
918 })
919 .unwrap_or(false)
920 }
921
    /// Whether the provider reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
925
    /// Returns whether all of the tool uses have finished running.
    ///
    /// Errored tool uses stay in the pending list, so "all pending uses are
    /// errors" is equivalent to "nothing is still running".
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }
935
936 /// Returns whether any pending tool uses may perform edits
937 pub fn has_pending_edit_tool_uses(&self) -> bool {
938 self.tool_use
939 .pending_tool_uses()
940 .iter()
941 .filter(|pending_tool_use| !pending_tool_use.status.is_error())
942 .any(|pending_tool_use| pending_tool_use.may_perform_edits)
943 }
944
    /// Tool uses issued by the assistant message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, &self.project, cx)
    }
948
    /// Tool results produced in response to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
955
    /// The result of the tool use with the given ID, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
959
960 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
961 match &self.tool_use.tool_result(id)?.content {
962 LanguageModelToolResultContent::Text(text) => Some(text),
963 LanguageModelToolResultContent::Image(_) => {
964 // TODO: We should display image
965 None
966 }
967 }
968 }
969
    /// The UI card rendered for the given tool use, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
973
974 /// Return tools that are both enabled and supported by the model
975 pub fn available_tools(
976 &self,
977 cx: &App,
978 model: Arc<dyn LanguageModel>,
979 ) -> Vec<LanguageModelRequestTool> {
980 if model.supports_tools() {
981 self.profile
982 .enabled_tools(cx)
983 .into_iter()
984 .filter_map(|(name, tool)| {
985 // Skip tools that cannot be supported
986 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
987 Some(LanguageModelRequestTool {
988 name: name.into(),
989 description: tool.description(),
990 input_schema,
991 })
992 })
993 .collect()
994 } else {
995 Vec::default()
996 }
997 }
998
999 pub fn insert_user_message(
1000 &mut self,
1001 text: impl Into<String>,
1002 loaded_context: ContextLoadResult,
1003 git_checkpoint: Option<GitStoreCheckpoint>,
1004 creases: Vec<MessageCrease>,
1005 cx: &mut Context<Self>,
1006 ) -> MessageId {
1007 if !loaded_context.referenced_buffers.is_empty() {
1008 self.action_log.update(cx, |log, cx| {
1009 for buffer in loaded_context.referenced_buffers {
1010 log.buffer_read(buffer, cx);
1011 }
1012 });
1013 }
1014
1015 let message_id = self.insert_message(
1016 Role::User,
1017 vec![MessageSegment::Text(text.into())],
1018 loaded_context.loaded_context,
1019 creases,
1020 false,
1021 cx,
1022 );
1023
1024 if let Some(git_checkpoint) = git_checkpoint {
1025 self.pending_checkpoint = Some(ThreadCheckpoint {
1026 message_id,
1027 git_checkpoint,
1028 });
1029 }
1030
1031 message_id
1032 }
1033
    /// Appends a hidden "Continue where you left off" user message, used to
    /// nudge the model onward without showing anything in the UI. Clears any
    /// pending checkpoint since this is not a real user turn.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }
1047
    /// Appends an assistant message with the given segments and no attached
    /// context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1062
1063 pub fn insert_message(
1064 &mut self,
1065 role: Role,
1066 segments: Vec<MessageSegment>,
1067 loaded_context: LoadedContext,
1068 creases: Vec<MessageCrease>,
1069 is_hidden: bool,
1070 cx: &mut Context<Self>,
1071 ) -> MessageId {
1072 let id = self.next_message_id.post_inc();
1073 self.messages.push(Message {
1074 id,
1075 role,
1076 segments,
1077 loaded_context,
1078 creases,
1079 is_hidden,
1080 ui_only: false,
1081 });
1082 self.touch_updated_at();
1083 cx.emit(ThreadEvent::MessageAdded(id));
1084 id
1085 }
1086
1087 pub fn edit_message(
1088 &mut self,
1089 id: MessageId,
1090 new_role: Role,
1091 new_segments: Vec<MessageSegment>,
1092 creases: Vec<MessageCrease>,
1093 loaded_context: Option<LoadedContext>,
1094 checkpoint: Option<GitStoreCheckpoint>,
1095 cx: &mut Context<Self>,
1096 ) -> bool {
1097 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1098 return false;
1099 };
1100 message.role = new_role;
1101 message.segments = new_segments;
1102 message.creases = creases;
1103 if let Some(context) = loaded_context {
1104 message.loaded_context = context;
1105 }
1106 if let Some(git_checkpoint) = checkpoint {
1107 self.checkpoints_by_message.insert(
1108 id,
1109 ThreadCheckpoint {
1110 message_id: id,
1111 git_checkpoint,
1112 },
1113 );
1114 }
1115 self.touch_updated_at();
1116 cx.emit(ThreadEvent::MessageEdited(id));
1117 true
1118 }
1119
1120 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1121 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1122 return false;
1123 };
1124 self.messages.remove(index);
1125 self.touch_updated_at();
1126 cx.emit(ThreadEvent::MessageDeleted(id));
1127 true
1128 }
1129
1130 /// Returns the representation of this [`Thread`] in a textual form.
1131 ///
1132 /// This is the representation we use when attaching a thread as context to another thread.
1133 pub fn text(&self) -> String {
1134 let mut text = String::new();
1135
1136 for message in &self.messages {
1137 text.push_str(match message.role {
1138 language_model::Role::User => "User:",
1139 language_model::Role::Assistant => "Agent:",
1140 language_model::Role::System => "System:",
1141 });
1142 text.push('\n');
1143
1144 for segment in &message.segments {
1145 match segment {
1146 MessageSegment::Text(content) => text.push_str(content),
1147 MessageSegment::Thinking { text: content, .. } => {
1148 text.push_str(&format!("<think>{}</think>", content))
1149 }
1150 MessageSegment::RedactedThinking(_) => {}
1151 }
1152 }
1153 text.push('\n');
1154 }
1155
1156 text
1157 }
1158
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Returns a [`Task`] because the initial project snapshot is produced
    /// asynchronously and must be awaited before the thread can be captured.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the shared snapshot future before reading the thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                // ui_only messages are ephemeral UI affordances; they are never
                // sent to the model and are not persisted either.
                messages: this
                    .messages()
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Only provider/model ids are persisted, not the live handles.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1245
    /// Returns how many agentic turns this thread may still take before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1249
    /// Resets the turn budget consumed by `send_to_model`.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1253
1254 pub fn send_to_model(
1255 &mut self,
1256 model: Arc<dyn LanguageModel>,
1257 intent: CompletionIntent,
1258 window: Option<AnyWindowHandle>,
1259 cx: &mut Context<Self>,
1260 ) {
1261 if self.remaining_turns == 0 {
1262 return;
1263 }
1264
1265 self.remaining_turns -= 1;
1266
1267 self.flush_notifications(model.clone(), intent, cx);
1268
1269 let _checkpoint = self.finalize_pending_checkpoint(cx);
1270 self.stream_completion(
1271 self.to_completion_request(model.clone(), intent, cx),
1272 model,
1273 intent,
1274 window,
1275 cx,
1276 );
1277 }
1278
1279 pub fn retry_last_completion(
1280 &mut self,
1281 window: Option<AnyWindowHandle>,
1282 cx: &mut Context<Self>,
1283 ) {
1284 // Clear any existing error state
1285 self.retry_state = None;
1286
1287 // Use the last error context if available, otherwise fall back to configured model
1288 let (model, intent) = if let Some((model, intent)) = self.last_error_context.take() {
1289 (model, intent)
1290 } else if let Some(configured_model) = self.configured_model.as_ref() {
1291 let model = configured_model.model.clone();
1292 let intent = if self.has_pending_tool_uses() {
1293 CompletionIntent::ToolResults
1294 } else {
1295 CompletionIntent::UserPrompt
1296 };
1297 (model, intent)
1298 } else if let Some(configured_model) = self.get_or_init_configured_model(cx) {
1299 let model = configured_model.model.clone();
1300 let intent = if self.has_pending_tool_uses() {
1301 CompletionIntent::ToolResults
1302 } else {
1303 CompletionIntent::UserPrompt
1304 };
1305 (model, intent)
1306 } else {
1307 return;
1308 };
1309
1310 self.send_to_model(model, intent, window, cx);
1311 }
1312
    /// Switches this thread to Burn completion mode and retries the last
    /// completion.
    ///
    /// Emits [`ThreadEvent::ProfileChanged`] so the UI reflects the mode
    /// switch before the retry is issued.
    pub fn enable_burn_mode_and_retry(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.completion_mode = CompletionMode::Burn;
        cx.emit(ThreadEvent::ProfileChanged);
        self.retry_last_completion(window, cx);
    }
1322
1323 pub fn used_tools_since_last_user_message(&self) -> bool {
1324 for message in self.messages.iter().rev() {
1325 if self.tool_use.message_has_tool_results(message.id) {
1326 return true;
1327 } else if message.role == Role::User {
1328 return false;
1329 }
1330 }
1331
1332 false
1333 }
1334
    /// Builds the [`LanguageModelRequest`] for this thread: system prompt,
    /// visible message history (with loaded context), tool uses paired with
    /// their results, and the set of tools the model may call.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
            thinking_allowed: true,
        };

        // The available tool names are fed into the system prompt via
        // `ModelContext` below.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    // Surface prompt-generation failures to the user; the
                    // request proceeds without a system message.
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt goes first and is marked cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible as a prompt-cache
        // breakpoint; exactly one message gets `cache = true` at the end.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Prepend any loaded context (attached files, etc.) for this message.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Trailing whitespace is dropped; fully-empty text is skipped.
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results are sent back as a synthetic User message that
            // follows the assistant message containing the tool uses.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use means this message is still changing,
                    // so it must not become the cache breakpoint.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn mode is only requested when the model actually supports it.
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1499
1500 fn to_summarize_request(
1501 &self,
1502 model: &Arc<dyn LanguageModel>,
1503 intent: CompletionIntent,
1504 added_user_message: String,
1505 cx: &App,
1506 ) -> LanguageModelRequest {
1507 let mut request = LanguageModelRequest {
1508 thread_id: None,
1509 prompt_id: None,
1510 intent: Some(intent),
1511 mode: None,
1512 messages: vec![],
1513 tools: Vec::new(),
1514 tool_choice: None,
1515 stop: Vec::new(),
1516 temperature: AgentSettings::temperature_for_model(model, cx),
1517 thinking_allowed: false,
1518 };
1519
1520 for message in &self.messages {
1521 let mut request_message = LanguageModelRequestMessage {
1522 role: message.role,
1523 content: Vec::new(),
1524 cache: false,
1525 };
1526
1527 for segment in &message.segments {
1528 match segment {
1529 MessageSegment::Text(text) => request_message
1530 .content
1531 .push(MessageContent::Text(text.clone())),
1532 MessageSegment::Thinking { .. } => {}
1533 MessageSegment::RedactedThinking(_) => {}
1534 }
1535 }
1536
1537 if request_message.content.is_empty() {
1538 continue;
1539 }
1540
1541 request.messages.push(request_message);
1542 }
1543
1544 request.messages.push(LanguageModelRequestMessage {
1545 role: Role::User,
1546 content: vec![MessageContent::Text(added_user_message)],
1547 cache: false,
1548 });
1549
1550 request
1551 }
1552
1553 /// Insert auto-generated notifications (if any) to the thread
1554 fn flush_notifications(
1555 &mut self,
1556 model: Arc<dyn LanguageModel>,
1557 intent: CompletionIntent,
1558 cx: &mut Context<Self>,
1559 ) {
1560 match intent {
1561 CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
1562 if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
1563 cx.emit(ThreadEvent::ToolFinished {
1564 tool_use_id: pending_tool_use.id.clone(),
1565 pending_tool_use: Some(pending_tool_use),
1566 });
1567 }
1568 }
1569 CompletionIntent::ThreadSummarization
1570 | CompletionIntent::ThreadContextSummarization
1571 | CompletionIntent::CreateFile
1572 | CompletionIntent::EditFile
1573 | CompletionIntent::InlineAssist
1574 | CompletionIntent::TerminalInlineAssist
1575 | CompletionIntent::GenerateGitCommitMessage => {}
1576 };
1577 }
1578
    /// When the user edited tracked files since the agent last saw them,
    /// records a simulated `project_notifications` tool call (with its output)
    /// on the most recent Assistant message so the model is informed.
    ///
    /// Returns `None` when the tool is unavailable or disabled for the current
    /// profile, when there are no unnotified edits, or when no Assistant
    /// message exists yet to attach the tool call to.
    fn attach_tracked_files_state(
        &mut self,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Option<PendingToolUse> {
        // Represent notification as a simulated `project_notifications` tool call
        let tool_name = Arc::from("project_notifications");
        let tool = self.tools.read(cx).tool(&tool_name, cx)?;

        // Respect the active profile's tool enablement.
        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return None;
        }

        // Nothing to report: no user edits accumulated since the last notification.
        if self
            .action_log
            .update(cx, |log, cx| log.unnotified_user_edits(cx).is_none())
        {
            return None;
        }

        let input = serde_json::json!({});
        let request = Arc::new(LanguageModelRequest::default()); // unused
        let window = None;
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model.clone(),
            window,
            cx,
        );

        // Id derived from the current message count, so successive simulated
        // calls get distinct ids.
        let tool_use_id =
            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));

        let tool_use = LanguageModelToolUse {
            id: tool_use_id.clone(),
            name: tool_name.clone(),
            raw_input: "{}".to_string(),
            input: serde_json::json!({}),
            is_input_complete: true,
        };

        // NOTE(review): blocks on the tool's output future; assumed to be fast
        // since this tool only reports pending notifications — confirm.
        let tool_output = cx.background_executor().block(tool_result.output);

        // Attach a project_notification tool call to the latest existing
        // Assistant message. We cannot create a new Assistant message
        // because thinking models require a `thinking` block that we
        // cannot mock. We cannot send a notification as a normal
        // (non-tool-use) User message because this distracts Agent
        // too much.
        let tool_message_id = self
            .messages
            .iter()
            .enumerate()
            .rfind(|(_, message)| message.role == Role::Assistant)
            .map(|(_, message)| message.id)?;

        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: self.last_prompt_id.clone(),
        };

        // Register the tool use first, then record its output against it.
        self.tool_use
            .request_tool_use(tool_message_id, tool_use, tool_use_metadata, cx);

        self.tool_use.insert_tool_output(
            tool_use_id,
            tool_name,
            tool_output,
            self.configured_model.as_ref(),
            self.completion_mode,
        )
    }
1655
    /// Streams a completion for `request` from `model`, applying events to the
    /// thread as they arrive (text, thinking, tool uses, usage updates), then
    /// handling the terminal outcome: stop reason, refusal cleanup, errors,
    /// and automatic retries. The work runs in a spawned task registered in
    /// `self.pending_completions`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is installed, capture the request and every
        // streamed event so the callback can inspect them once the stream ends.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let completion_mode = request
            .mode
            .unwrap_or(cloud_llm_client::CompletionMode::Normal);

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, cx);
            // Token usage before this request, so the telemetry delta can be
            // computed at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message created for this response, once known.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        match event? {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // only the delta since the last update is added.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking { data } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses need an assistant message to hang off;
                                // create one lazily if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Incomplete input means the tool-use JSON is
                                // still streaming in.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some(tool_use.input.clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            } => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued { position } => {
                                            completion.queue_state =
                                                QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code,
                                            message,
                                            request_id: _,
                                            retry_after,
                                        } => {
                                            // Convert cloud failures into completion
                                            // errors so the retry logic below sees them.
                                            return Err(
                                                LanguageModelCompletionError::from_cloud_failure(
                                                    model.upstream_provider_name(),
                                                    code,
                                                    message,
                                                    retry_after.map(Duration::from_secs_f64),
                                                ),
                                            );
                                        }
                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
                                            thread.update_model_request_usage(
                                                amount as u32,
                                                limit,
                                                cx,
                                            );
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        Ok(())
                    })??;

                    // Yield so other tasks get a chance to run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;
            let mut retry_scheduled = false;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => {
                            match stop_reason {
                                StopReason::ToolUse => {
                                    let tool_uses =
                                        thread.use_pending_tools(window, model.clone(), cx);
                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                                }
                                StopReason::EndTurn | StopReason::MaxTokens => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });
                                }
                                StopReason::Refusal => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });

                                    // Remove the turn that was refused.
                                    //
                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                    {
                                        let mut messages_to_remove = Vec::new();

                                        // Walk backwards collecting message ids until
                                        // the start of the refused turn is found.
                                        for (ix, message) in
                                            thread.messages.iter().enumerate().rev()
                                        {
                                            messages_to_remove.push(message.id);

                                            if message.role == Role::User {
                                                if ix == 0 {
                                                    break;
                                                }

                                                if let Some(prev_message) =
                                                    thread.messages.get(ix - 1)
                                                    && prev_message.role == Role::Assistant
                                                {
                                                    break;
                                                }
                                            }
                                        }

                                        for message_id in messages_to_remove {
                                            thread.delete_message(message_id, cx);
                                        }
                                    }

                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                        header: "Language model refusal".into(),
                                        message:
                                            "Model refused to generate content for safety reasons."
                                                .into(),
                                    }));
                                }
                            }

                            // We successfully completed, so cancel any remaining retries.
                            thread.retry_state = None;
                        }
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(completion_error) =
                                error.downcast_ref::<LanguageModelCompletionError>()
                            {
                                match &completion_error {
                                    LanguageModelCompletionError::PromptTooLarge {
                                        tokens, ..
                                    } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread
                                                .total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(
                                                    model
                                                        .max_token_count_for_mode(completion_mode)
                                                        .saturating_add(1),
                                                )
                                        });
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: tokens,
                                        });
                                        cx.notify();
                                    }
                                    _ => {
                                        if let Some(retry_strategy) =
                                            Thread::get_retry_strategy(completion_error)
                                        {
                                            log::info!(
                                                "Retrying with {:?} for language model completion error {:?}",
                                                retry_strategy,
                                                completion_error
                                            );

                                            retry_scheduled = thread
                                                .handle_retryable_error_with_delay(
                                                    completion_error,
                                                    Some(retry_strategy),
                                                    model.clone(),
                                                    intent,
                                                    window,
                                                    cx,
                                                );
                                        }
                                    }
                                }
                            }

                            if !retry_scheduled {
                                thread.cancel_last_completion(window, cx);
                            }
                        }
                    }

                    if !retry_scheduled {
                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
                    }

                    // Hand the captured request/events to the installed callback, if any.
                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    // Report the token-usage delta for this request to telemetry.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
2097
2098 pub fn summarize(&mut self, cx: &mut Context<Self>) {
2099 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2100 println!("No thread summary model");
2101 return;
2102 };
2103
2104 if !model.provider.is_authenticated(cx) {
2105 return;
2106 }
2107
2108 let request = self.to_summarize_request(
2109 &model.model,
2110 CompletionIntent::ThreadSummarization,
2111 SUMMARIZE_THREAD_PROMPT.into(),
2112 cx,
2113 );
2114
2115 self.summary = ThreadSummary::Generating;
2116
2117 self.pending_summary = cx.spawn(async move |this, cx| {
2118 let result = async {
2119 let mut messages = model.model.stream_completion(request, cx).await?;
2120
2121 let mut new_summary = String::new();
2122 while let Some(event) = messages.next().await {
2123 let Ok(event) = event else {
2124 continue;
2125 };
2126 let text = match event {
2127 LanguageModelCompletionEvent::Text(text) => text,
2128 LanguageModelCompletionEvent::StatusUpdate(
2129 CompletionRequestStatus::UsageUpdated { amount, limit },
2130 ) => {
2131 this.update(cx, |thread, cx| {
2132 thread.update_model_request_usage(amount as u32, limit, cx);
2133 })?;
2134 continue;
2135 }
2136 _ => continue,
2137 };
2138
2139 let mut lines = text.lines();
2140 new_summary.extend(lines.next());
2141
2142 // Stop if the LLM generated multiple lines.
2143 if lines.next().is_some() {
2144 break;
2145 }
2146 }
2147
2148 anyhow::Ok(new_summary)
2149 }
2150 .await;
2151
2152 this.update(cx, |this, cx| {
2153 match result {
2154 Ok(new_summary) => {
2155 if new_summary.is_empty() {
2156 this.summary = ThreadSummary::Error;
2157 } else {
2158 this.summary = ThreadSummary::Ready(new_summary.into());
2159 }
2160 }
2161 Err(err) => {
2162 this.summary = ThreadSummary::Error;
2163 log::error!("Failed to generate thread summary: {}", err);
2164 }
2165 }
2166 cx.emit(ThreadEvent::SummaryGenerated);
2167 })
2168 .log_err()?;
2169
2170 Some(())
2171 });
2172 }
2173
    /// Maps a completion error to a retry strategy, or `None` when retrying
    /// cannot help (bad credentials, oversized prompt, exhausted quota, …).
    ///
    /// NOTE: arm order is significant — specific `HttpResponseError` status
    /// codes are matched before the generic 4xx/5xx guard arm and the final
    /// catch-all.
    fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
        use LanguageModelCompletionError::*;

        // General strategy here:
        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
        match error {
            HttpResponseError {
                status_code: StatusCode::TOO_MANY_REQUESTS,
                ..
            } => Some(RetryStrategy::ExponentialBackoff {
                initial_delay: BASE_RETRY_DELAY,
                max_attempts: MAX_RETRY_ATTEMPTS,
            }),
            // Honor the server-provided `retry_after` when available.
            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
                Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    max_attempts: MAX_RETRY_ATTEMPTS,
                })
            }
            UpstreamProviderError {
                status,
                retry_after,
                ..
            } => match *status {
                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
                    Some(RetryStrategy::Fixed {
                        delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                        max_attempts: MAX_RETRY_ATTEMPTS,
                    })
                }
                StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    // Internal Server Error could be anything, retry up to 3 times.
                    max_attempts: 3,
                }),
                status => {
                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
                    // but we frequently get them in practice. See https://http.dev/529
                    if status.as_u16() == 529 {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: MAX_RETRY_ATTEMPTS,
                        })
                    } else {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: 2,
                        })
                    }
                }
            },
            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            ApiReadResponseError { .. }
            | HttpSend { .. }
            | DeserializeResponse { .. }
            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Retrying these errors definitely shouldn't help.
            HttpResponseError {
                status_code:
                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
                ..
            }
            | AuthenticationError { .. }
            | PermissionError { .. }
            | NoApiKey { .. }
            | ApiEndpointNotFound { .. }
            | PromptTooLarge { .. } => None,
            // These errors might be transient, so retry them
            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 1,
            }),
            // Retry all other 4xx and 5xx errors once.
            HttpResponseError { status_code, .. }
                if status_code.is_client_error() || status_code.is_server_error() =>
            {
                Some(RetryStrategy::Fixed {
                    delay: BASE_RETRY_DELAY,
                    max_attempts: 3,
                })
            }
            Other(err)
                if err.is::<PaymentRequiredError>()
                    || err.is::<ModelRequestLimitReachedError>() =>
            {
                // Retrying won't help for Payment Required or Model Request Limit errors (where
                // the user must upgrade to usage-based billing to get more requests, or else wait
                // for a significant amount of time for the request limit to reset).
                None
            }
            // Conservatively assume that any other errors are non-retryable
            HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 2,
            }),
        }
    }
2279
    /// Handles a completion error that may be worth retrying.
    ///
    /// Returns `true` when a retry was scheduled. Returns `false` when no retry
    /// will happen: Burn Mode is disabled (the user is shown an error with a
    /// "retry / enable Burn Mode" option instead), the error has no applicable
    /// retry strategy, or the maximum number of attempts has been exhausted.
    fn handle_retryable_error_with_delay(
        &mut self,
        error: &LanguageModelCompletionError,
        strategy: Option<RetryStrategy>,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Store context for the Retry button
        self.last_error_context = Some((model.clone(), intent));

        // Only auto-retry if Burn Mode is enabled
        if self.completion_mode != CompletionMode::Burn {
            // Show error with retry options
            cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
                message: format!(
                    "{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.",
                    error
                )
                .into(),
                can_enable_burn_mode: true,
            }));
            return false;
        }

        // An explicit strategy from the caller wins; otherwise derive one from the error.
        let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
            return false;
        };

        let max_attempts = match &strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
        };

        // NOTE(review): if a retry sequence is already in flight, the existing
        // state (including its original max_attempts and intent) is reused; the
        // new strategy only influences the delay computed below.
        let retry_state = self.retry_state.get_or_insert(RetryState {
            attempt: 0,
            max_attempts,
            intent,
        });

        retry_state.attempt += 1;
        let attempt = retry_state.attempt;
        let max_attempts = retry_state.max_attempts;
        let intent = retry_state.intent;

        if attempt <= max_attempts {
            // Exponential backoff doubles the initial delay for each attempt.
            let delay = match &strategy {
                RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
                    let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
                    Duration::from_secs(delay_secs)
                }
                RetryStrategy::Fixed { delay, .. } => *delay,
            };

            // Add a transient message to inform the user
            let delay_secs = delay.as_secs();
            let retry_message = if max_attempts == 1 {
                format!("{error}. Retrying in {delay_secs} seconds...")
            } else {
                format!(
                    "{error}. Retrying (attempt {attempt} of {max_attempts}) \
                     in {delay_secs} seconds..."
                )
            };
            log::warn!(
                "Retrying completion request (attempt {attempt} of {max_attempts}) \
                 in {delay_secs} seconds: {error:?}",
            );

            // Add a UI-only message instead of a regular message
            let id = self.next_message_id.post_inc();
            self.messages.push(Message {
                id,
                role: Role::System,
                segments: vec![MessageSegment::Text(retry_message)],
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
                is_hidden: false,
                ui_only: true,
            });
            cx.emit(ThreadEvent::MessageAdded(id));

            // Schedule the retry
            let thread_handle = cx.entity().downgrade();

            cx.spawn(async move |_thread, cx| {
                cx.background_executor().timer(delay).await;

                thread_handle
                    .update(cx, |thread, cx| {
                        // Retry the completion
                        thread.send_to_model(model, intent, window, cx);
                    })
                    .log_err();
            })
            .detach();

            true
        } else {
            // Max retries exceeded
            self.retry_state = None;

            // Stop generating since we're giving up on retrying.
            self.pending_completions.clear();

            // Show error alongside a Retry button, but no
            // Enable Burn Mode button (since it's already enabled)
            cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
                message: format!("Failed after retrying: {}", error).into(),
                can_enable_burn_mode: false,
            }));

            false
        }
    }
2396
    /// Kicks off background generation of a detailed thread summary, unless a
    /// summary is already generating or generated for the current last message.
    ///
    /// The result is published through `detailed_summary_tx`, and the thread is
    /// saved afterwards so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        // Summaries are keyed by the last message they cover; if one already
        // exists (or is being produced) for the current last message, bail out.
        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            SUMMARIZE_THREAD_DETAILED_PROMPT.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, cx);
            // If the request fails outright, roll the state back to NotGenerated.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the streamed chunks; chunk errors are logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade()
                && let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                })
            {
                save_task.await.log_err();
            }

            Some(())
        });
    }
2484
2485 pub async fn wait_for_detailed_summary_or_text(
2486 this: &Entity<Self>,
2487 cx: &mut AsyncApp,
2488 ) -> Option<SharedString> {
2489 let mut detailed_summary_rx = this
2490 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2491 .ok()?;
2492 loop {
2493 match detailed_summary_rx.recv().await? {
2494 DetailedSummaryState::Generating { .. } => {}
2495 DetailedSummaryState::NotGenerated => {
2496 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2497 }
2498 DetailedSummaryState::Generated { text, .. } => return Some(text),
2499 }
2500 }
2501 }
2502
2503 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2504 self.detailed_summary_rx
2505 .borrow()
2506 .text()
2507 .unwrap_or_else(|| self.text().into())
2508 }
2509
    /// Whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2516
2517 pub fn use_pending_tools(
2518 &mut self,
2519 window: Option<AnyWindowHandle>,
2520 model: Arc<dyn LanguageModel>,
2521 cx: &mut Context<Self>,
2522 ) -> Vec<PendingToolUse> {
2523 let request =
2524 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2525 let pending_tool_uses = self
2526 .tool_use
2527 .pending_tool_uses()
2528 .into_iter()
2529 .filter(|tool_use| tool_use.status.is_idle())
2530 .cloned()
2531 .collect::<Vec<_>>();
2532
2533 for tool_use in pending_tool_uses.iter() {
2534 self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2535 }
2536
2537 pending_tool_uses
2538 }
2539
2540 fn use_pending_tool(
2541 &mut self,
2542 tool_use: PendingToolUse,
2543 request: Arc<LanguageModelRequest>,
2544 model: Arc<dyn LanguageModel>,
2545 window: Option<AnyWindowHandle>,
2546 cx: &mut Context<Self>,
2547 ) {
2548 let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2549 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2550 };
2551
2552 if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2553 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2554 }
2555
2556 if tool.needs_confirmation(&tool_use.input, &self.project, cx)
2557 && !AgentSettings::get_global(cx).always_allow_tool_actions
2558 {
2559 self.tool_use.confirm_tool_use(
2560 tool_use.id,
2561 tool_use.ui_text,
2562 tool_use.input,
2563 request,
2564 tool,
2565 );
2566 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2567 } else {
2568 self.run_tool(
2569 tool_use.id,
2570 tool_use.ui_text,
2571 tool_use.input,
2572 request,
2573 tool,
2574 model,
2575 window,
2576 cx,
2577 );
2578 }
2579 }
2580
2581 pub fn handle_hallucinated_tool_use(
2582 &mut self,
2583 tool_use_id: LanguageModelToolUseId,
2584 hallucinated_tool_name: Arc<str>,
2585 window: Option<AnyWindowHandle>,
2586 cx: &mut Context<Thread>,
2587 ) {
2588 let available_tools = self.profile.enabled_tools(cx);
2589
2590 let tool_list = available_tools
2591 .iter()
2592 .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2593 .collect::<Vec<_>>()
2594 .join("\n");
2595
2596 let error_message = format!(
2597 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2598 hallucinated_tool_name, tool_list
2599 );
2600
2601 let pending_tool_use = self.tool_use.insert_tool_output(
2602 tool_use_id.clone(),
2603 hallucinated_tool_name,
2604 Err(anyhow!("Missing tool call: {error_message}")),
2605 self.configured_model.as_ref(),
2606 self.completion_mode,
2607 );
2608
2609 cx.emit(ThreadEvent::MissingToolUse {
2610 tool_use_id: tool_use_id.clone(),
2611 ui_text: error_message.into(),
2612 });
2613
2614 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2615 }
2616
2617 pub fn receive_invalid_tool_json(
2618 &mut self,
2619 tool_use_id: LanguageModelToolUseId,
2620 tool_name: Arc<str>,
2621 invalid_json: Arc<str>,
2622 error: String,
2623 window: Option<AnyWindowHandle>,
2624 cx: &mut Context<Thread>,
2625 ) {
2626 log::error!("The model returned invalid input JSON: {invalid_json}");
2627
2628 let pending_tool_use = self.tool_use.insert_tool_output(
2629 tool_use_id.clone(),
2630 tool_name,
2631 Err(anyhow!("Error parsing input JSON: {error}")),
2632 self.configured_model.as_ref(),
2633 self.completion_mode,
2634 );
2635 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2636 pending_tool_use.ui_text.clone()
2637 } else {
2638 log::error!(
2639 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2640 );
2641 format!("Unknown tool {}", tool_use_id).into()
2642 };
2643
2644 cx.emit(ThreadEvent::InvalidToolInput {
2645 tool_use_id: tool_use_id.clone(),
2646 ui_text,
2647 invalid_input_json: invalid_json,
2648 });
2649
2650 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2651 }
2652
2653 pub fn run_tool(
2654 &mut self,
2655 tool_use_id: LanguageModelToolUseId,
2656 ui_text: impl Into<SharedString>,
2657 input: serde_json::Value,
2658 request: Arc<LanguageModelRequest>,
2659 tool: Arc<dyn Tool>,
2660 model: Arc<dyn LanguageModel>,
2661 window: Option<AnyWindowHandle>,
2662 cx: &mut Context<Thread>,
2663 ) {
2664 let task =
2665 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2666 self.tool_use
2667 .run_pending_tool(tool_use_id, ui_text.into(), task);
2668 }
2669
2670 fn spawn_tool_use(
2671 &mut self,
2672 tool_use_id: LanguageModelToolUseId,
2673 request: Arc<LanguageModelRequest>,
2674 input: serde_json::Value,
2675 tool: Arc<dyn Tool>,
2676 model: Arc<dyn LanguageModel>,
2677 window: Option<AnyWindowHandle>,
2678 cx: &mut Context<Thread>,
2679 ) -> Task<()> {
2680 let tool_name: Arc<str> = tool.name().into();
2681
2682 let tool_result = tool.run(
2683 input,
2684 request,
2685 self.project.clone(),
2686 self.action_log.clone(),
2687 model,
2688 window,
2689 cx,
2690 );
2691
2692 // Store the card separately if it exists
2693 if let Some(card) = tool_result.card.clone() {
2694 self.tool_use
2695 .insert_tool_result_card(tool_use_id.clone(), card);
2696 }
2697
2698 cx.spawn({
2699 async move |thread: WeakEntity<Thread>, cx| {
2700 let output = tool_result.output.await;
2701
2702 thread
2703 .update(cx, |thread, cx| {
2704 let pending_tool_use = thread.tool_use.insert_tool_output(
2705 tool_use_id.clone(),
2706 tool_name,
2707 output,
2708 thread.configured_model.as_ref(),
2709 thread.completion_mode,
2710 );
2711 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2712 })
2713 .ok();
2714 }
2715 })
2716 }
2717
2718 fn tool_finished(
2719 &mut self,
2720 tool_use_id: LanguageModelToolUseId,
2721 pending_tool_use: Option<PendingToolUse>,
2722 canceled: bool,
2723 window: Option<AnyWindowHandle>,
2724 cx: &mut Context<Self>,
2725 ) {
2726 if self.all_tools_finished()
2727 && let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref()
2728 && !canceled
2729 {
2730 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2731 }
2732
2733 cx.emit(ThreadEvent::ToolFinished {
2734 tool_use_id,
2735 pending_tool_use,
2736 });
2737 }
2738
2739 /// Cancels the last pending completion, if there are any pending.
2740 ///
2741 /// Returns whether a completion was canceled.
2742 pub fn cancel_last_completion(
2743 &mut self,
2744 window: Option<AnyWindowHandle>,
2745 cx: &mut Context<Self>,
2746 ) -> bool {
2747 let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2748
2749 self.retry_state = None;
2750
2751 for pending_tool_use in self.tool_use.cancel_pending() {
2752 canceled = true;
2753 self.tool_finished(
2754 pending_tool_use.id.clone(),
2755 Some(pending_tool_use),
2756 true,
2757 window,
2758 cx,
2759 );
2760 }
2761
2762 if canceled {
2763 cx.emit(ThreadEvent::CompletionCanceled);
2764
2765 // When canceled, we always want to insert the checkpoint.
2766 // (We skip over finalize_pending_checkpoint, because it
2767 // would conclude we didn't have anything to insert here.)
2768 if let Some(checkpoint) = self.pending_checkpoint.take() {
2769 self.insert_checkpoint(checkpoint, cx);
2770 }
2771 } else {
2772 self.finalize_pending_checkpoint(cx);
2773 }
2774
2775 canceled
2776 }
2777
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    ///
    /// It only emits [`ThreadEvent::CancelEditing`] and changes no state itself.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2785
    /// Returns the feedback (if any) the user recorded for `message_id`.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2789
    /// Records the user's `feedback` for `message_id` and reports it — along
    /// with a serialized copy of the thread and a project snapshot — to
    /// telemetry. No-ops if the same feedback was already recorded.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything the telemetry payload needs before moving to the
        // background executor.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|(name, _)| name.clone().into())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_message_content())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure degrades to a null payload rather than
            // dropping the whole telemetry event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2846
2847 /// Create a snapshot of the current project state including git information and unsaved buffers.
2848 fn project_snapshot(
2849 project: Entity<Project>,
2850 cx: &mut Context<Self>,
2851 ) -> Task<Arc<ProjectSnapshot>> {
2852 let git_store = project.read(cx).git_store().clone();
2853 let worktree_snapshots: Vec<_> = project
2854 .read(cx)
2855 .visible_worktrees(cx)
2856 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2857 .collect();
2858
2859 cx.spawn(async move |_, _| {
2860 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2861
2862 Arc::new(ProjectSnapshot {
2863 worktree_snapshots,
2864 timestamp: Utc::now(),
2865 })
2866 })
2867 }
2868
    /// Captures a [`WorktreeSnapshot`] for `worktree`: its absolute path plus,
    /// when a local git repository covers it, branch/remote/HEAD/diff state.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().into_owned();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot rather than failing.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose working directory covers this worktree.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        // Only local repositories can answer remote/HEAD/diff
                        // queries; other states get a branch-only GitState.
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the Option<Result<job>> produced above, tolerating
            // entity-update failures and job errors at each layer.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2946
2947 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2948 let mut markdown = Vec::new();
2949
2950 let summary = self.summary().or_default();
2951 writeln!(markdown, "# {summary}\n")?;
2952
2953 for message in self.messages() {
2954 writeln!(
2955 markdown,
2956 "## {role}\n",
2957 role = match message.role {
2958 Role::User => "User",
2959 Role::Assistant => "Agent",
2960 Role::System => "System",
2961 }
2962 )?;
2963
2964 if !message.loaded_context.text.is_empty() {
2965 writeln!(markdown, "{}", message.loaded_context.text)?;
2966 }
2967
2968 if !message.loaded_context.images.is_empty() {
2969 writeln!(
2970 markdown,
2971 "\n{} images attached as context.\n",
2972 message.loaded_context.images.len()
2973 )?;
2974 }
2975
2976 for segment in &message.segments {
2977 match segment {
2978 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2979 MessageSegment::Thinking { text, .. } => {
2980 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2981 }
2982 MessageSegment::RedactedThinking(_) => {}
2983 }
2984 }
2985
2986 for tool_use in self.tool_uses_for_message(message.id, cx) {
2987 writeln!(
2988 markdown,
2989 "**Use Tool: {} ({})**",
2990 tool_use.name, tool_use.id
2991 )?;
2992 writeln!(markdown, "```json")?;
2993 writeln!(
2994 markdown,
2995 "{}",
2996 serde_json::to_string_pretty(&tool_use.input)?
2997 )?;
2998 writeln!(markdown, "```")?;
2999 }
3000
3001 for tool_result in self.tool_results_for_message(message.id) {
3002 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
3003 if tool_result.is_error {
3004 write!(markdown, " (Error)")?;
3005 }
3006
3007 writeln!(markdown, "**\n")?;
3008 match &tool_result.content {
3009 LanguageModelToolResultContent::Text(text) => {
3010 writeln!(markdown, "{text}")?;
3011 }
3012 LanguageModelToolResultContent::Image(image) => {
3013 writeln!(markdown, "", image.source)?;
3014 }
3015 }
3016
3017 if let Some(output) = tool_result.output.as_ref() {
3018 writeln!(
3019 markdown,
3020 "\n\nDebug Output:\n\n```json\n{}\n```\n",
3021 serde_json::to_string_pretty(output)?
3022 )?;
3023 }
3024 }
3025 }
3026
3027 Ok(String::from_utf8_lossy(&markdown).to_string())
3028 }
3029
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept
    /// (accepted), delegating to the [`ActionLog`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
3040
    /// Marks every outstanding agent edit as kept (accepted).
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
3045
    /// Rejects (reverts) the agent's edits within `buffer_ranges` of `buffer`,
    /// delegating to the [`ActionLog`]. The returned task resolves once the
    /// rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
3056
    /// Returns the log of agent-performed edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
3060
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
3064
    /// Total token usage accumulated across this thread's completions.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3068
3069 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3070 let Some(model) = self.configured_model.as_ref() else {
3071 return TotalTokenUsage::default();
3072 };
3073
3074 let max = model
3075 .model
3076 .max_token_count_for_mode(self.completion_mode().into());
3077
3078 let index = self
3079 .messages
3080 .iter()
3081 .position(|msg| msg.id == message_id)
3082 .unwrap_or(0);
3083
3084 if index == 0 {
3085 return TotalTokenUsage { total: 0, max };
3086 }
3087
3088 let token_usage = &self
3089 .request_token_usage
3090 .get(index - 1)
3091 .cloned()
3092 .unwrap_or_default();
3093
3094 TotalTokenUsage {
3095 total: token_usage.total_tokens(),
3096 max,
3097 }
3098 }
3099
3100 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3101 let model = self.configured_model.as_ref()?;
3102
3103 let max = model
3104 .model
3105 .max_token_count_for_mode(self.completion_mode().into());
3106
3107 if let Some(exceeded_error) = &self.exceeded_window_error
3108 && model.model.id() == exceeded_error.model_id
3109 {
3110 return Some(TotalTokenUsage {
3111 total: exceeded_error.token_count,
3112 max,
3113 });
3114 }
3115
3116 let total = self
3117 .token_usage_at_last_message()
3118 .unwrap_or_default()
3119 .total_tokens();
3120
3121 Some(TotalTokenUsage { total, max })
3122 }
3123
    /// Token usage recorded for the last message, if any.
    ///
    /// Usage entries are aligned with message indices; if the usage vec is
    /// shorter than the message list, fall back to the most recent entry.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
3130
    /// Records `token_usage` for the last message, growing the per-message
    /// usage vec so its indices stay aligned with `self.messages`.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Pad any missing entries with the last known usage (rather than
        // zeros) so intermediate messages don't appear to be free.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
3140
3141 fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3142 self.project
3143 .read(cx)
3144 .user_store()
3145 .update(cx, |user_store, cx| {
3146 user_store.update_model_request_usage(
3147 ModelRequestUsage(RequestUsage {
3148 amount: amount as i32,
3149 limit,
3150 }),
3151 cx,
3152 )
3153 });
3154 }
3155
3156 pub fn deny_tool_use(
3157 &mut self,
3158 tool_use_id: LanguageModelToolUseId,
3159 tool_name: Arc<str>,
3160 window: Option<AnyWindowHandle>,
3161 cx: &mut Context<Self>,
3162 ) {
3163 let err = Err(anyhow::anyhow!(
3164 "Permission to run tool action denied by user"
3165 ));
3166
3167 self.tool_use.insert_tool_output(
3168 tool_use_id.clone(),
3169 tool_name,
3170 err,
3171 self.configured_model.as_ref(),
3172 self.completion_mode,
3173 );
3174 self.tool_finished(tool_use_id, None, true, window, cx);
3175 }
3176}
3177
/// User-facing errors surfaced by a thread via `ThreadEvent::ShowError`.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider refused the request pending payment.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has no model requests remaining.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a display header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
    /// An error the user may resolve by retrying; `can_enable_burn_mode`
    /// controls whether an "Enable Burn Mode" action is also offered.
    #[error("Retryable error: {message}")]
    RetryableError {
        message: SharedString,
        can_enable_burn_mode: bool,
    },
}
3195
/// Events emitted by a thread for the UI and other observers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Display an error to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// Streamed assistant text appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model called a tool that doesn't exist or isn't enabled.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// Generation stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    /// Listeners (e.g. the active thread view) should cancel in-progress edits.
    CancelEditing,
    /// The active completion (or scheduled retry / pending tools) was canceled.
    CompletionCanceled,
    ProfileChanged,
}
3240
// Lets `Thread` entities emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
3242
/// A completion request currently in flight.
struct PendingCompletion {
    // Identifier for this completion; presumably unique per request — verify at call sites.
    id: usize,
    queue_state: QueueState,
    // The task driving the completion; held so it keeps running.
    _task: Task<()>,
}
3248
3249#[cfg(test)]
3250mod tests {
3251 use super::*;
3252 use crate::{
3253 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3254 };
3255
3256 // Test-specific constants
3257 const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3258 use agent_settings::{AgentProfileId, AgentSettings};
3259 use assistant_tool::ToolRegistry;
3260 use assistant_tools;
3261 use fs::Fs;
3262 use futures::StreamExt;
3263 use futures::future::BoxFuture;
3264 use futures::stream::BoxStream;
3265 use gpui::TestAppContext;
3266 use http_client;
3267 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3268 use language_model::{
3269 LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3270 LanguageModelProviderName, LanguageModelToolChoice,
3271 };
3272 use parking_lot::Mutex;
3273 use project::{FakeFs, Project};
3274 use prompt_store::PromptBuilder;
3275 use serde_json::json;
3276 use settings::{LanguageModelParameters, Settings, SettingsStore};
3277 use std::sync::Arc;
3278 use std::time::Duration;
3279 use util::path;
3280 use workspace::Workspace;
3281
    /// Verifies that attached file context is rendered into the user message
    /// (wrapped in a `<context>` block) and prepended to the message text in
    /// the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
 println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Request contains the system message plus our user message; the
        // context block is prepended to the user-visible text.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3359
    /// Each successive user message should only attach context that was not
    /// already attached to an earlier message, and `new_context_for_thread`
    /// should honor both an optional "exclude up to this message" cursor and
    /// removals from the context store.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        // Four files so each step can introduce a distinct context item.
        let project = create_test_project(
            &fs,
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages (plus the leading system message)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cursor at message2, everything attached from message2 onward
        // (file2, file3) plus the newly added file4 counts as "new"; only the
        // context from message1 is excluded.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        // NOTE: `loaded_context.contexts[2]` relies on the load preserving
        // insertion order (file2, file3, file4) as asserted above.
        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3516
3517 #[gpui::test]
3518 async fn test_message_without_files(cx: &mut TestAppContext) {
3519 let fs = init_test_settings(cx);
3520
3521 let project = create_test_project(
3522 &fs,
3523 cx,
3524 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3525 )
3526 .await;
3527
3528 let (_, _thread_store, thread, _context_store, model) =
3529 setup_test_environment(cx, project.clone()).await;
3530
3531 // Insert user message without any context (empty context vector)
3532 let message_id = thread.update(cx, |thread, cx| {
3533 thread.insert_user_message(
3534 "What is the best way to learn Rust?",
3535 ContextLoadResult::default(),
3536 None,
3537 Vec::new(),
3538 cx,
3539 )
3540 });
3541
3542 // Check content and context in message object
3543 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3544
3545 // Context should be empty when no files are included
3546 assert_eq!(message.role, Role::User);
3547 assert_eq!(message.segments.len(), 1);
3548 assert_eq!(
3549 message.segments[0],
3550 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3551 );
3552 assert_eq!(message.loaded_context.text, "");
3553
3554 // Check message in request
3555 let request = thread.update(cx, |thread, cx| {
3556 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3557 });
3558
3559 assert_eq!(request.messages.len(), 2);
3560 assert_eq!(
3561 request.messages[1].string_contents(),
3562 "What is the best way to learn Rust?"
3563 );
3564
3565 // Add second message, also without context
3566 let message2_id = thread.update(cx, |thread, cx| {
3567 thread.insert_user_message(
3568 "Are there any good books?",
3569 ContextLoadResult::default(),
3570 None,
3571 Vec::new(),
3572 cx,
3573 )
3574 });
3575
3576 let message2 =
3577 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3578 assert_eq!(message2.loaded_context.text, "");
3579
3580 // Check that both messages appear in the request
3581 let request = thread.update(cx, |thread, cx| {
3582 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3583 });
3584
3585 assert_eq!(request.messages.len(), 3);
3586 assert_eq!(
3587 request.messages[1].string_contents(),
3588 "What is the best way to learn Rust?"
3589 );
3590 assert_eq!(
3591 request.messages[2].string_contents(),
3592 "Are there any good books?"
3593 );
3594 }
3595
    /// When a context-tracked buffer is edited after being attached, the next
    /// `flush_notifications` should surface exactly one `project_notifications`
    /// tool result mentioning the stale file — and flushing again must not
    /// produce a duplicate.
    #[gpui::test]
    #[ignore] // turn this test on when project_notifications tool is re-enabled
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context. This will be a tracked buffer
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message and assistant response
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });
        cx.run_until_parked();

        // We shouldn't have a stale buffer notification yet
        let notifications = thread.read_with(cx, |thread, _| {
            find_tool_uses(thread, "project_notifications")
        });
        assert!(
            notifications.is_empty(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Modify the buffer so the context snapshot becomes stale.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });
        cx.run_until_parked();

        // Flushing should now emit the stale-buffer warning as a tool use.
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        cx.run_until_parked();

        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        // Exactly one notification is expected at this point.
        let [notification] = notifications.as_slice() else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        assert!(notification_content.contains("These files have changed since the last read:"));
        assert!(notification_content.contains("code.rs"));

        // Insert another user message and flush notifications again
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Can you tell me more?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        cx.run_until_parked();

        // There should be no new notifications (we already flushed one)
        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        assert_eq!(
            notifications.len(),
            1,
            "Should still have only one notification after second flush - no duplicates"
        );
    }
3711
3712 fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3713 thread
3714 .messages()
3715 .flat_map(|message| {
3716 thread
3717 .tool_results_for_message(message.id)
3718 .into_iter()
3719 .filter(|result| result.tool_name == tool_name.into())
3720 .cloned()
3721 .collect::<Vec<_>>()
3722 })
3723 .collect()
3724 }
3725
3726 #[gpui::test]
3727 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3728 let fs = init_test_settings(cx);
3729
3730 let project = create_test_project(
3731 &fs,
3732 cx,
3733 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3734 )
3735 .await;
3736
3737 let (_workspace, thread_store, thread, _context_store, _model) =
3738 setup_test_environment(cx, project.clone()).await;
3739
3740 // Check that we are starting with the default profile
3741 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3742 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3743 assert_eq!(
3744 profile,
3745 AgentProfile::new(AgentProfileId::default(), tool_set)
3746 );
3747 }
3748
3749 #[gpui::test]
3750 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3751 let fs = init_test_settings(cx);
3752
3753 let project = create_test_project(
3754 &fs,
3755 cx,
3756 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3757 )
3758 .await;
3759
3760 let (_workspace, thread_store, thread, _context_store, _model) =
3761 setup_test_environment(cx, project.clone()).await;
3762
3763 // Profile gets serialized with default values
3764 let serialized = thread
3765 .update(cx, |thread, cx| thread.serialize(cx))
3766 .await
3767 .unwrap();
3768
3769 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3770
3771 let deserialized = cx.update(|cx| {
3772 thread.update(cx, |thread, cx| {
3773 Thread::deserialize(
3774 thread.id.clone(),
3775 serialized,
3776 thread.project.clone(),
3777 thread.tools.clone(),
3778 thread.prompt_builder.clone(),
3779 thread.project_context.clone(),
3780 None,
3781 cx,
3782 )
3783 })
3784 });
3785 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3786
3787 assert_eq!(
3788 deserialized.profile,
3789 AgentProfile::new(AgentProfileId::default(), tool_set)
3790 );
3791 }
3792
3793 #[gpui::test]
3794 async fn test_temperature_setting(cx: &mut TestAppContext) {
3795 let fs = init_test_settings(cx);
3796
3797 let project = create_test_project(
3798 &fs,
3799 cx,
3800 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3801 )
3802 .await;
3803
3804 let (_workspace, _thread_store, thread, _context_store, model) =
3805 setup_test_environment(cx, project.clone()).await;
3806
3807 // Both model and provider
3808 cx.update(|cx| {
3809 AgentSettings::override_global(
3810 AgentSettings {
3811 model_parameters: vec![LanguageModelParameters {
3812 provider: Some(model.provider_id().0.to_string().into()),
3813 model: Some(model.id().0),
3814 temperature: Some(0.66),
3815 }],
3816 ..AgentSettings::get_global(cx).clone()
3817 },
3818 cx,
3819 );
3820 });
3821
3822 let request = thread.update(cx, |thread, cx| {
3823 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3824 });
3825 assert_eq!(request.temperature, Some(0.66));
3826
3827 // Only model
3828 cx.update(|cx| {
3829 AgentSettings::override_global(
3830 AgentSettings {
3831 model_parameters: vec![LanguageModelParameters {
3832 provider: None,
3833 model: Some(model.id().0),
3834 temperature: Some(0.66),
3835 }],
3836 ..AgentSettings::get_global(cx).clone()
3837 },
3838 cx,
3839 );
3840 });
3841
3842 let request = thread.update(cx, |thread, cx| {
3843 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3844 });
3845 assert_eq!(request.temperature, Some(0.66));
3846
3847 // Only provider
3848 cx.update(|cx| {
3849 AgentSettings::override_global(
3850 AgentSettings {
3851 model_parameters: vec![LanguageModelParameters {
3852 provider: Some(model.provider_id().0.to_string().into()),
3853 model: None,
3854 temperature: Some(0.66),
3855 }],
3856 ..AgentSettings::get_global(cx).clone()
3857 },
3858 cx,
3859 );
3860 });
3861
3862 let request = thread.update(cx, |thread, cx| {
3863 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3864 });
3865 assert_eq!(request.temperature, Some(0.66));
3866
3867 // Same model name, different provider
3868 cx.update(|cx| {
3869 AgentSettings::override_global(
3870 AgentSettings {
3871 model_parameters: vec![LanguageModelParameters {
3872 provider: Some("anthropic".into()),
3873 model: Some(model.id().0),
3874 temperature: Some(0.66),
3875 }],
3876 ..AgentSettings::get_global(cx).clone()
3877 },
3878 cx,
3879 );
3880 });
3881
3882 let request = thread.update(cx, |thread, cx| {
3883 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3884 });
3885 assert_eq!(request.temperature, None);
3886 }
3887
    /// Exercises the summary state machine: Pending -> Generating (after the
    /// first exchange) -> Ready, and verifies manual `set_summary` is only
    /// honored once a summary is Ready.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        // Complete the assistant turn so the thread has a full exchange.
        let fake_model = model.as_fake();
        simulate_successful_response(fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summarization response in two chunks, then close it.
        cx.run_until_parked();
        fake_model.send_last_completion_stream_text_chunk("Brief");
        fake_model.send_last_completion_stream_text_chunk(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set (chunks concatenated in order)
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3972
3973 #[gpui::test]
3974 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3975 let fs = init_test_settings(cx);
3976
3977 let project = create_test_project(&fs, cx, json!({})).await;
3978
3979 let (_, _thread_store, thread, _context_store, model) =
3980 setup_test_environment(cx, project.clone()).await;
3981
3982 test_summarize_error(&model, &thread, cx);
3983
3984 // Now we should be able to set a summary
3985 thread.update(cx, |thread, cx| {
3986 thread.set_summary("Brief Intro", cx);
3987 });
3988
3989 thread.read_with(cx, |thread, _| {
3990 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3991 assert_eq!(thread.summary().or_default(), "Brief Intro");
3992 });
3993 }
3994
    /// After summarization fails, further normal messages must not re-trigger
    /// summarization automatically, but calling `summarize` manually should
    /// retry and can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Put the summary into the Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Complete the manual summarization stream.
        cx.run_until_parked();
        fake_model.send_last_completion_stream_text_chunk("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
4045
    /// Which error `ErrorInjector` should surface from `stream_completion`,
    /// used to exercise the thread's retry logic.
    enum TestError {
        /// The provider reports itself as overloaded.
        Overloaded,
        /// The provider reports an internal server error.
        InternalServerError,
    }
4051
    /// A `LanguageModel` implementation that delegates metadata to a wrapped
    /// `FakeLanguageModel` but always fails `stream_completion` with the
    /// configured error, for testing retry behavior.
    struct ErrorInjector {
        // Delegate for ids, names, capability flags, and token counting.
        inner: Arc<FakeLanguageModel>,
        // The error every completion stream will produce.
        error_type: TestError,
    }
4056
4057 impl ErrorInjector {
4058 fn new(error_type: TestError) -> Self {
4059 Self {
4060 inner: Arc::new(FakeLanguageModel::default()),
4061 error_type,
4062 }
4063 }
4064 }
4065
    // Every method except `stream_completion` delegates to the wrapped fake
    // model; `stream_completion` always yields the configured error.
    impl LanguageModel for ErrorInjector {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        /// Ignores the request and returns a stream whose single item is the
        /// error selected at construction time.
        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Build the error eagerly (outside the future) so it captures
            // `self.error_type` without borrowing `self` across the await.
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
                    provider: self.provider_name(),
                    retry_after: None,
                },
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError {
                        provider: self.provider_name(),
                        message: "I'm a teapot orbiting the sun".to_string(),
                    }
                }
            };
            async move {
                // The request itself "succeeds"; the failure arrives as the
                // one and only stream item.
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        // Tests drive the underlying fake model (e.g. to end streams) through
        // this accessor.
        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
4148
4149 #[gpui::test]
4150 async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4151 let fs = init_test_settings(cx);
4152
4153 let project = create_test_project(&fs, cx, json!({})).await;
4154 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4155
4156 // Enable Burn Mode to allow retries
4157 thread.update(cx, |thread, _| {
4158 thread.set_completion_mode(CompletionMode::Burn);
4159 });
4160
4161 // Create model that returns overloaded error
4162 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4163
4164 // Insert a user message
4165 thread.update(cx, |thread, cx| {
4166 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4167 });
4168
4169 // Start completion
4170 thread.update(cx, |thread, cx| {
4171 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4172 });
4173
4174 cx.run_until_parked();
4175
4176 thread.read_with(cx, |thread, _| {
4177 assert!(thread.retry_state.is_some(), "Should have retry state");
4178 let retry_state = thread.retry_state.as_ref().unwrap();
4179 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4180 assert_eq!(
4181 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4182 "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
4183 );
4184 });
4185
4186 // Check that a retry message was added
4187 thread.read_with(cx, |thread, _| {
4188 let mut messages = thread.messages();
4189 assert!(
4190 messages.any(|msg| {
4191 msg.role == Role::System
4192 && msg.ui_only
4193 && msg.segments.iter().any(|seg| {
4194 if let MessageSegment::Text(text) = seg {
4195 text.contains("overloaded")
4196 && text
4197 .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4198 } else {
4199 false
4200 }
4201 })
4202 }),
4203 "Should have added a system retry message"
4204 );
4205 });
4206
4207 let retry_count = thread.update(cx, |thread, _| {
4208 thread
4209 .messages
4210 .iter()
4211 .filter(|m| {
4212 m.ui_only
4213 && m.segments.iter().any(|s| {
4214 if let MessageSegment::Text(text) = s {
4215 text.contains("Retrying") && text.contains("seconds")
4216 } else {
4217 false
4218 }
4219 })
4220 })
4221 .count()
4222 });
4223
4224 assert_eq!(retry_count, 1, "Should have one retry message");
4225 }
4226
    /// An internal-server-error completion in Burn mode should arm the retry
    /// state for 3 attempts and post a single ui-only retry notice that names
    /// the provider.
    #[gpui::test]
    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // Create model that returns internal server error
        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Check retry state on thread
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, 3,
                "Should have correct max attempts"
            );
        });

        // Check that a retry message was added with provider name
        // ("Fake" is the wrapped FakeLanguageModel's provider).
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("internal")
                                    && text.contains("Fake")
                                    && text.contains("Retrying")
                                    && text.contains("attempt 1 of 3")
                                    && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message with provider name"
            );
        });

        // Count retry messages
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4308
4309 #[gpui::test]
4310 async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4311 let fs = init_test_settings(cx);
4312
4313 let project = create_test_project(&fs, cx, json!({})).await;
4314 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4315
4316 // Enable Burn Mode to allow retries
4317 thread.update(cx, |thread, _| {
4318 thread.set_completion_mode(CompletionMode::Burn);
4319 });
4320
4321 // Create model that returns internal server error
4322 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4323
4324 // Insert a user message
4325 thread.update(cx, |thread, cx| {
4326 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4327 });
4328
4329 // Track retry events and completion count
4330 // Track completion events
4331 let completion_count = Arc::new(Mutex::new(0));
4332 let completion_count_clone = completion_count.clone();
4333
4334 let _subscription = thread.update(cx, |_, cx| {
4335 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4336 if let ThreadEvent::NewRequest = event {
4337 *completion_count_clone.lock() += 1;
4338 }
4339 })
4340 });
4341
4342 // First attempt
4343 thread.update(cx, |thread, cx| {
4344 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4345 });
4346 cx.run_until_parked();
4347
4348 // Should have scheduled first retry - count retry messages
4349 let retry_count = thread.update(cx, |thread, _| {
4350 thread
4351 .messages
4352 .iter()
4353 .filter(|m| {
4354 m.ui_only
4355 && m.segments.iter().any(|s| {
4356 if let MessageSegment::Text(text) = s {
4357 text.contains("Retrying") && text.contains("seconds")
4358 } else {
4359 false
4360 }
4361 })
4362 })
4363 .count()
4364 });
4365 assert_eq!(retry_count, 1, "Should have scheduled first retry");
4366
4367 // Check retry state
4368 thread.read_with(cx, |thread, _| {
4369 assert!(thread.retry_state.is_some(), "Should have retry state");
4370 let retry_state = thread.retry_state.as_ref().unwrap();
4371 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4372 assert_eq!(
4373 retry_state.max_attempts, 3,
4374 "Internal server errors should retry up to 3 times"
4375 );
4376 });
4377
4378 // Advance clock for first retry
4379 cx.executor().advance_clock(BASE_RETRY_DELAY);
4380 cx.run_until_parked();
4381
4382 // Advance clock for second retry
4383 cx.executor().advance_clock(BASE_RETRY_DELAY);
4384 cx.run_until_parked();
4385
4386 // Advance clock for third retry
4387 cx.executor().advance_clock(BASE_RETRY_DELAY);
4388 cx.run_until_parked();
4389
4390 // Should have completed all retries - count retry messages
4391 let retry_count = thread.update(cx, |thread, _| {
4392 thread
4393 .messages
4394 .iter()
4395 .filter(|m| {
4396 m.ui_only
4397 && m.segments.iter().any(|s| {
4398 if let MessageSegment::Text(text) = s {
4399 text.contains("Retrying") && text.contains("seconds")
4400 } else {
4401 false
4402 }
4403 })
4404 })
4405 .count()
4406 });
4407 assert_eq!(
4408 retry_count, 3,
4409 "Should have 3 retries for internal server errors"
4410 );
4411
4412 // For internal server errors, we retry 3 times and then give up
4413 // Check that retry_state is cleared after all retries
4414 thread.read_with(cx, |thread, _| {
4415 assert!(
4416 thread.retry_state.is_none(),
4417 "Retry state should be cleared after all retries"
4418 );
4419 });
4420
4421 // Verify total attempts (1 initial + 3 retries)
4422 assert_eq!(
4423 *completion_count.lock(),
4424 4,
4425 "Should have attempted once plus 3 retries"
4426 );
4427 }
4428
4429 #[gpui::test]
4430 async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4431 let fs = init_test_settings(cx);
4432
4433 let project = create_test_project(&fs, cx, json!({})).await;
4434 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4435
4436 // Enable Burn Mode to allow retries
4437 thread.update(cx, |thread, _| {
4438 thread.set_completion_mode(CompletionMode::Burn);
4439 });
4440
4441 // Create model that returns overloaded error
4442 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4443
4444 // Insert a user message
4445 thread.update(cx, |thread, cx| {
4446 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4447 });
4448
4449 // Track events
4450 let stopped_with_error = Arc::new(Mutex::new(false));
4451 let stopped_with_error_clone = stopped_with_error.clone();
4452
4453 let _subscription = thread.update(cx, |_, cx| {
4454 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4455 if let ThreadEvent::Stopped(Err(_)) = event {
4456 *stopped_with_error_clone.lock() = true;
4457 }
4458 })
4459 });
4460
4461 // Start initial completion
4462 thread.update(cx, |thread, cx| {
4463 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4464 });
4465 cx.run_until_parked();
4466
4467 // Advance through all retries
4468 for _ in 0..MAX_RETRY_ATTEMPTS {
4469 cx.executor().advance_clock(BASE_RETRY_DELAY);
4470 cx.run_until_parked();
4471 }
4472
4473 let retry_count = thread.update(cx, |thread, _| {
4474 thread
4475 .messages
4476 .iter()
4477 .filter(|m| {
4478 m.ui_only
4479 && m.segments.iter().any(|s| {
4480 if let MessageSegment::Text(text) = s {
4481 text.contains("Retrying") && text.contains("seconds")
4482 } else {
4483 false
4484 }
4485 })
4486 })
4487 .count()
4488 });
4489
4490 // After max retries, should emit Stopped(Err(...)) event
4491 assert_eq!(
4492 retry_count, MAX_RETRY_ATTEMPTS as usize,
4493 "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
4494 );
4495 assert!(
4496 *stopped_with_error.lock(),
4497 "Should emit Stopped(Err(...)) event after max retries exceeded"
4498 );
4499
4500 // Retry state should be cleared
4501 thread.read_with(cx, |thread, _| {
4502 assert!(
4503 thread.retry_state.is_none(),
4504 "Retry state should be cleared after max retries"
4505 );
4506
4507 // Verify we have the expected number of retry messages
4508 let retry_messages = thread
4509 .messages
4510 .iter()
4511 .filter(|msg| msg.ui_only && msg.role == Role::System)
4512 .count();
4513 assert_eq!(
4514 retry_messages, MAX_RETRY_ATTEMPTS as usize,
4515 "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
4516 );
4517 });
4518 }
4519
    #[gpui::test]
    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
        // Scenario: the first completion fails with ServerOverloaded, a retry
        // is scheduled, and the retry succeeds. Verifies that the retry
        // notification survives as a ui_only System message after success.
        // NOTE(review): the test name says "removed", but the assertions at
        // the bottom check the message still *exists* — consider renaming.
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // We'll use a wrapper to switch behavior after first failure
        struct RetryTestModel {
            // Delegate that serves the successful (second) attempt.
            inner: Arc<FakeLanguageModel>,
            // Set to true after the first stream_completion call has failed.
            failed_once: Arc<Mutex<bool>>,
        }

        // All metadata/capability methods simply delegate to the inner fake;
        // only stream_completion injects the one-time failure.
        impl LanguageModel for RetryTestModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }

            // Exposes the inner fake so the test can drive its streams.
            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track when the retried completion streams successfully.
        let retry_completed = Arc::new(Mutex::new(false));
        let retry_completed_clone = retry_completed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::StreamedCompletion = event {
                    *retry_completed_clone.lock() = true;
                }
            })
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Get the retry message ID (the ui-only System message inserted when
        // the first attempt failed), so we can check it again later.
        let retry_message_id = thread.read_with(cx, |thread, _| {
            thread
                .messages()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .map(|msg| msg.id)
                .expect("Should have a retry message")
        });

        // Wait for retry
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();

        // Stream some successful content
        let fake_model = model.as_fake();
        // After the retry, there should be a new pending completion
        let pending = fake_model.pending_completions();
        assert!(
            !pending.is_empty(),
            "Should have a pending completion after retry"
        );
        fake_model.send_completion_stream_text_chunk(&pending[0], "Success!");
        fake_model.end_completion_stream(&pending[0]);
        cx.run_until_parked();

        // Check that the retry completed successfully
        assert!(
            *retry_completed.lock(),
            "Retry should have completed successfully"
        );

        // Retry message should still exist but be marked as ui_only
        thread.read_with(cx, |thread, _| {
            let retry_msg = thread
                .message(retry_message_id)
                .expect("Retry message should still exist");
            assert!(retry_msg.ui_only, "Retry message should be ui_only");
            assert_eq!(
                retry_msg.role,
                Role::System,
                "Retry message should have System role"
            );
        });
    }
4692
    #[gpui::test]
    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
        // Scenario: the first completion fails, a retry is scheduled, and the
        // retry succeeds. Verifies retry_state is cleared after the success
        // and that a real (non-ui_only) assistant message was produced.
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // Create a model that fails once then succeeds
        struct FailOnceModel {
            // Delegate that serves the successful (second) attempt.
            inner: Arc<FakeLanguageModel>,
            // Set to true after the first stream_completion call has failed.
            failed_once: Arc<Mutex<bool>>,
        }

        // All metadata/capability methods simply delegate to the inner fake;
        // only stream_completion injects the one-time failure.
        impl LanguageModel for FailOnceModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }
        }

        let fail_once_model = Arc::new(FailOnceModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Test message",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
        });

        // Start completion with fail-once model
        thread.update(cx, |thread, cx| {
            thread.send_to_model(
                fail_once_model.clone(),
                CompletionIntent::UserPrompt,
                None,
                cx,
            );
        });

        cx.run_until_parked();

        // Verify retry state exists after first failure
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_some(),
                "Should have retry state after failure"
            );
        });

        // Wait for retry delay
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();

        // The retry should now use our FailOnceModel which should succeed
        // We need to help the FakeLanguageModel complete the stream
        let inner_fake = fail_once_model.inner.clone();

        // Wait a bit for the retry to start
        cx.run_until_parked();

        // Check for pending completions and complete them
        if let Some(pending) = inner_fake.pending_completions().first() {
            inner_fake.send_completion_stream_text_chunk(pending, "Success!");
            inner_fake.end_completion_stream(pending);
        }
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            // A successful completion must wipe the retry bookkeeping.
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after successful completion"
            );

            // The retried request should have produced a genuine assistant
            // reply, not just ui-only retry notifications.
            let has_assistant_message = thread
                .messages
                .iter()
                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
            assert!(
                has_assistant_message,
                "Should have an assistant message after successful retry"
            );
        });
    }
4858
    #[gpui::test]
    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
        // Scenario: the provider answers every request with RateLimitExceeded
        // carrying an explicit retry_after. Verifies the first failure
        // schedules a retry with retry_state using MAX_RETRY_ATTEMPTS and an
        // attempt-counting notification message.
        // NOTE(review): the test name says "single_attempt", but the
        // assertions below expect max_attempts == MAX_RETRY_ATTEMPTS —
        // consider renaming.
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // Create a model that returns rate limit error with retry_after
        struct RateLimitModel {
            inner: Arc<FakeLanguageModel>,
        }

        // All metadata/capability methods delegate to the inner fake;
        // stream_completion always fails with RateLimitExceeded.
        impl LanguageModel for RateLimitModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                _request: LanguageModelRequest,
                _cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                let provider = self.provider_name();
                // Every attempt yields a stream whose only item is a
                // RateLimitExceeded error with a provider-supplied delay.
                async move {
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::RateLimitExceeded {
                            provider,
                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
                        })
                    });
                    Ok(stream.boxed())
                }
                .boxed()
            }

            // Exposes the inner fake so the test can drive its streams.
            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RateLimitModel {
            inner: Arc::new(FakeLanguageModel::default()),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Count the ui-only notifications mentioning the rate limit.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("rate limit exceeded")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(retry_count, 1, "Should have scheduled one retry");

        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_some(),
                "Rate limit errors should set retry_state"
            );
            if let Some(retry_state) = &thread.retry_state {
                assert_eq!(
                    retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                    "Rate limit errors should use MAX_RETRY_ATTEMPTS"
                );
            }
        });

        // Verify we have one retry message
        thread.read_with(cx, |thread, _| {
            let retry_messages = thread
                .messages
                .iter()
                .filter(|msg| {
                    msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("rate limit exceeded")
                            } else {
                                false
                            }
                        })
                })
                .count();
            assert_eq!(
                retry_messages, 1,
                "Should have one rate limit retry message"
            );
        });

        // Check that the retry message includes the attempt count
        thread.read_with(cx, |thread, _| {
            let retry_message = thread
                .messages
                .iter()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .expect("Should have a retry message");

            // Check that the message contains attempt count since we use retry_state
            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
                assert!(
                    text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)),
                    "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
                );
                assert!(
                    text.contains("Retrying"),
                    "Rate limit retry message should contain retry text"
                );
            }
        });
    }
5043
5044 #[gpui::test]
5045 async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5046 let fs = init_test_settings(cx);
5047
5048 let project = create_test_project(&fs, cx, json!({})).await;
5049 let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5050
5051 // Insert a regular user message
5052 thread.update(cx, |thread, cx| {
5053 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5054 });
5055
5056 // Insert a UI-only message (like our retry notifications)
5057 thread.update(cx, |thread, cx| {
5058 let id = thread.next_message_id.post_inc();
5059 thread.messages.push(Message {
5060 id,
5061 role: Role::System,
5062 segments: vec![MessageSegment::Text(
5063 "This is a UI-only message that should not be sent to the model".to_string(),
5064 )],
5065 loaded_context: LoadedContext::default(),
5066 creases: Vec::new(),
5067 is_hidden: true,
5068 ui_only: true,
5069 });
5070 cx.emit(ThreadEvent::MessageAdded(id));
5071 });
5072
5073 // Insert another regular message
5074 thread.update(cx, |thread, cx| {
5075 thread.insert_user_message(
5076 "How are you?",
5077 ContextLoadResult::default(),
5078 None,
5079 vec![],
5080 cx,
5081 );
5082 });
5083
5084 // Generate the completion request
5085 let request = thread.update(cx, |thread, cx| {
5086 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5087 });
5088
5089 // Verify that the request only contains non-UI-only messages
5090 // Should have system prompt + 2 user messages, but not the UI-only message
5091 let user_messages: Vec<_> = request
5092 .messages
5093 .iter()
5094 .filter(|msg| msg.role == Role::User)
5095 .collect();
5096 assert_eq!(
5097 user_messages.len(),
5098 2,
5099 "Should have exactly 2 user messages"
5100 );
5101
5102 // Verify the UI-only content is not present anywhere in the request
5103 let request_text = request
5104 .messages
5105 .iter()
5106 .flat_map(|msg| &msg.content)
5107 .filter_map(|content| match content {
5108 MessageContent::Text(text) => Some(text.as_str()),
5109 _ => None,
5110 })
5111 .collect::<String>();
5112
5113 assert!(
5114 !request_text.contains("UI-only message"),
5115 "UI-only message content should not be in the request"
5116 );
5117
5118 // Verify the thread still has all 3 messages (including UI-only)
5119 thread.read_with(cx, |thread, _| {
5120 assert_eq!(
5121 thread.messages().count(),
5122 3,
5123 "Thread should have 3 messages"
5124 );
5125 assert_eq!(
5126 thread.messages().filter(|m| m.ui_only).count(),
5127 1,
5128 "Thread should have 1 UI-only message"
5129 );
5130 });
5131
5132 // Verify that UI-only messages are not serialized
5133 let serialized = thread
5134 .update(cx, |thread, cx| thread.serialize(cx))
5135 .await
5136 .unwrap();
5137 assert_eq!(
5138 serialized.messages.len(),
5139 2,
5140 "Serialized thread should only have 2 messages (no UI-only)"
5141 );
5142 }
5143
5144 #[gpui::test]
5145 async fn test_no_retry_without_burn_mode(cx: &mut TestAppContext) {
5146 let fs = init_test_settings(cx);
5147
5148 let project = create_test_project(&fs, cx, json!({})).await;
5149 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5150
5151 // Ensure we're in Normal mode (not Burn mode)
5152 thread.update(cx, |thread, _| {
5153 thread.set_completion_mode(CompletionMode::Normal);
5154 });
5155
5156 // Track error events
5157 let error_events = Arc::new(Mutex::new(Vec::new()));
5158 let error_events_clone = error_events.clone();
5159
5160 let _subscription = thread.update(cx, |_, cx| {
5161 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
5162 if let ThreadEvent::ShowError(error) = event {
5163 error_events_clone.lock().push(error.clone());
5164 }
5165 })
5166 });
5167
5168 // Create model that returns overloaded error
5169 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5170
5171 // Insert a user message
5172 thread.update(cx, |thread, cx| {
5173 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5174 });
5175
5176 // Start completion
5177 thread.update(cx, |thread, cx| {
5178 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5179 });
5180
5181 cx.run_until_parked();
5182
5183 // Verify no retry state was created
5184 thread.read_with(cx, |thread, _| {
5185 assert!(
5186 thread.retry_state.is_none(),
5187 "Should not have retry state in Normal mode"
5188 );
5189 });
5190
5191 // Check that a retryable error was reported
5192 let errors = error_events.lock();
5193 assert!(!errors.is_empty(), "Should have received an error event");
5194
5195 if let ThreadError::RetryableError {
5196 message: _,
5197 can_enable_burn_mode,
5198 } = &errors[0]
5199 {
5200 assert!(
5201 *can_enable_burn_mode,
5202 "Error should indicate burn mode can be enabled"
5203 );
5204 } else {
5205 panic!("Expected RetryableError, got {:?}", errors[0]);
5206 }
5207
5208 // Verify the thread is no longer generating
5209 thread.read_with(cx, |thread, _| {
5210 assert!(
5211 !thread.is_generating(),
5212 "Should not be generating after error without retry"
5213 );
5214 });
5215 }
5216
5217 #[gpui::test]
5218 async fn test_retry_canceled_on_stop(cx: &mut TestAppContext) {
5219 let fs = init_test_settings(cx);
5220
5221 let project = create_test_project(&fs, cx, json!({})).await;
5222 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5223
5224 // Enable Burn Mode to allow retries
5225 thread.update(cx, |thread, _| {
5226 thread.set_completion_mode(CompletionMode::Burn);
5227 });
5228
5229 // Create model that returns overloaded error
5230 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5231
5232 // Insert a user message
5233 thread.update(cx, |thread, cx| {
5234 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5235 });
5236
5237 // Start completion
5238 thread.update(cx, |thread, cx| {
5239 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5240 });
5241
5242 cx.run_until_parked();
5243
5244 // Verify retry was scheduled by checking for retry message
5245 let has_retry_message = thread.read_with(cx, |thread, _| {
5246 thread.messages.iter().any(|m| {
5247 m.ui_only
5248 && m.segments.iter().any(|s| {
5249 if let MessageSegment::Text(text) = s {
5250 text.contains("Retrying") && text.contains("seconds")
5251 } else {
5252 false
5253 }
5254 })
5255 })
5256 });
5257 assert!(has_retry_message, "Should have scheduled a retry");
5258
5259 // Cancel the completion before the retry happens
5260 thread.update(cx, |thread, cx| {
5261 thread.cancel_last_completion(None, cx);
5262 });
5263
5264 cx.run_until_parked();
5265
5266 // The retry should not have happened - no pending completions
5267 let fake_model = model.as_fake();
5268 assert_eq!(
5269 fake_model.pending_completions().len(),
5270 0,
5271 "Should have no pending completions after cancellation"
5272 );
5273
5274 // Verify the retry was canceled by checking retry state
5275 thread.read_with(cx, |thread, _| {
5276 if let Some(retry_state) = &thread.retry_state {
5277 panic!(
5278 "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5279 retry_state.attempt, retry_state.max_attempts, retry_state.intent
5280 );
5281 }
5282 });
5283 }
5284
    // Shared helper: drives `thread` through a summarization request against
    // `model` and asserts the summary ends in the Error state when the summary
    // stream finishes without producing text.
    // NOTE(review): assumes `model` wraps a fake (calls `as_fake()`); pass a
    // FakeLanguageModel-backed model only.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        // Send a user message and immediately request a thread summary.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(fake_model, cx);

        // While the summary request is in flight, the summary reports
        // Generating and falls back to the default title.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
5319
    // Drives the fake model through a minimal successful completion: one text
    // chunk followed by end-of-stream, parking the executor on both sides so
    // the thread can process the resulting events.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.send_last_completion_stream_text_chunk("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
5326
    // Registers every global (settings, themes, tools, providers) the agent
    // tests rely on and returns the fake filesystem the test project will use.
    fn init_test_settings(cx: &mut TestAppContext) -> Arc<dyn Fs> {
        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| {
            // The settings store is installed first; the init/register calls
            // below read or register settings against this global.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(fs.clone(), cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            theme::init(theme::LoadThemes::JustBase, cx);
            ToolRegistry::default_global(cx);
            assistant_tool::init(cx);

            // assistant_tools needs an HTTP client; every request gets a
            // canned 200 response so no test touches the network.
            let http_client = Arc::new(http_client::HttpClientWithUrl::new(
                http_client::FakeHttpClient::with_200_response(),
                "http://localhost".to_string(),
                None,
            ));
            assistant_tools::init(http_client, cx);
        });
        fs
    }
5352
5353 // Helper to create a test project with test files
5354 async fn create_test_project(
5355 fs: &Arc<dyn Fs>,
5356 cx: &mut TestAppContext,
5357 files: serde_json::Value,
5358 ) -> Entity<Project> {
5359 fs.as_fake().insert_tree(path!("/test"), files).await;
5360 Project::test(fs.clone(), [path!("/test").as_ref()], cx).await
5361 }
5362
    // Builds the standard test fixture: a Workspace window, a ThreadStore with
    // an empty tool set, a fresh Thread, a ContextStore, and a fake language
    // model registered as both the default and the thread-summary model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // ThreadStore::load is async; no prompt store (None) and an empty
        // tool working set keep the fixture minimal.
        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider::default());
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model for both roles so tests that trigger
        // summarization don't fall back to an unconfigured model.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
5417
5418 async fn add_file_to_context(
5419 project: &Entity<Project>,
5420 context_store: &Entity<ContextStore>,
5421 path: &str,
5422 cx: &mut TestAppContext,
5423 ) -> Result<Entity<language::Buffer>> {
5424 let buffer_path = project
5425 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5426 .unwrap();
5427
5428 let buffer = project
5429 .update(cx, |project, cx| {
5430 project.open_buffer(buffer_path.clone(), cx)
5431 })
5432 .await
5433 .unwrap();
5434
5435 context_store.update(cx, |context_store, cx| {
5436 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5437 });
5438
5439 Ok(buffer)
5440 }
5441}