use crate::{
    agent_profile::AgentProfile,
    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
    thread_store::{
        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
        ThreadStore,
    },
    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
};
use action_log::ActionLog;
use agent_settings::{
    AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
    SUMMARIZE_THREAD_PROMPT,
};
use anyhow::{Result, anyhow};
use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, future::Shared};
use git::repository::DiffType;
use gpui::{
    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
    WeakEntity, Window,
};
use http_client::StatusCode;
use language_model::{
    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
    TokenUsage,
};
use postage::stream::Stream as _;
use project::{
    Project,
    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
};
use prompt_store::{ModelContext, PromptBuilder};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
    io::Write,
    ops::Range,
    sync::Arc,
    time::{Duration, Instant},
};
use thiserror::Error;
use util::{ResultExt as _, post_inc};
use uuid::Uuid;

const MAX_RETRY_ATTEMPTS: u8 = 4;
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);

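/// Strategy for retrying a failed completion request.
///
/// A sketch of the intended schedule (the actual delay computation lives in
/// the retry handling further down this file): `ExponentialBackoff` doubles
/// the wait on each attempt, while `Fixed` waits the same `delay` every time.
///
/// ```ignore
/// // Hypothetical 1-based `attempt`:
/// let delay = initial_delay * 2u32.pow(u32::from(attempt - 1));
/// ```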
#[derive(Debug, Clone)]
enum RetryStrategy {
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}

#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}

/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub usize);

impl MessageId {
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    pub fn as_usize(&self) -> usize {
        self.0
    }
}

/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    pub segments: Vec<MessageSegment>,
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
    pub is_hidden: bool,
    pub ui_only: bool,
}

impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    /// The model sometimes runs tools without producing any text, or produces just a marker ([`USING_TOOL_MARKER`]).
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }

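    /// Appends streamed thinking text, merging with the last segment when it
    /// is already a `Thinking` segment and replacing its signature if a new
    /// one is provided. A minimal sketch of the merge behavior:
    ///
    /// ```ignore
    /// message.push_thinking("Let me", None);
    /// message.push_thinking(" think...", Some("sig".into()));
    /// // One Thinking segment remains: "Let me think..." with signature "sig".
    /// ```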
    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        if let Some(MessageSegment::Thinking {
            text: segment,
            signature: current_signature,
        }) = self.segments.last_mut()
        {
            if let Some(signature) = signature {
                *current_signature = Some(signature);
            }
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Thinking {
                text: text.to_string(),
                signature,
            });
        }
    }

    pub fn push_redacted_thinking(&mut self, data: String) {
        self.segments.push(MessageSegment::RedactedThinking(data));
    }

    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

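    /// Flattens the message into a single string: loaded context text first,
    /// then each segment in order, with thinking segments wrapped in
    /// `<think>` tags and redacted thinking omitted.
    ///
    /// ```ignore
    /// // Sketch: a text segment "Hello" followed by a thinking segment
    /// // "reasoning" yields "Hello<think>\nreasoning\n</think>".
    /// ```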
    pub fn to_message_content(&self) -> String {
        let mut result = String::new();

        if !self.loaded_context.text.is_empty() {
            result.push_str(&self.loaded_context.text);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    result.push_str("<think>\n");
                    result.push_str(text);
                    result.push_str("\n</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        result
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(String),
}

impl MessageSegment {
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }

    pub fn text(&self) -> Option<&str> {
        match self {
            MessageSegment::Text(text) => Some(text),
            _ => None,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}

impl LastRestoreCheckpoint {
    pub fn message_id(&self) -> MessageId {
        match self {
            LastRestoreCheckpoint::Pending { message_id } => *message_id,
            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
        }
    }
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}

impl DetailedSummaryState {
    fn text(&self) -> Option<SharedString> {
        if let Self::Generated { text, .. } = self {
            Some(text.clone())
        } else {
            None
        }
    }
}

#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: u64,
    pub max: u64,
}

impl TotalTokenUsage {
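    /// Classifies the current usage against the model's context window. A
    /// sketch of the thresholds, assuming the default 0.8 warning threshold:
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 850, max: 1000 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Warning); // 850/1000 >= 0.8
    /// assert_eq!(usage.add(150).ratio(), TokenUsageRatio::Exceeded); // 1000 >= 1000
    /// ```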
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or("0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}

#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}

/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
    last_error_context: Option<(Arc<dyn LanguageModel>, CompletionIntent)>,
}

#[derive(Clone, Debug)]
struct RetryState {
    attempt: u8,
    max_attempts: u8,
    intent: CompletionIntent,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    Pending,
    Generating,
    Ready(SharedString),
    Error,
}

impl ThreadSummary {
    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");

    pub fn or_default(&self) -> SharedString {
        self.unwrap_or(Self::DEFAULT)
    }

    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
        self.ready().unwrap_or_else(|| message.into())
    }

    pub fn ready(&self) -> Option<SharedString> {
        match self {
            ThreadSummary::Ready(summary) => Some(summary.clone()),
            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}

impl Thread {
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_error_context: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }

    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            message_feedback: HashMap::default(),
            last_error_context: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }

    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
            + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// Returns `None` when no completion is currently streaming.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

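    /// Restores the project to the given checkpoint: the restore is first
    /// marked as pending, the git store is asked to roll the working tree
    /// back to the saved checkpoint, and on success the thread is truncated
    /// at the checkpoint's message. On failure, the error is recorded so the
    /// UI can surface it next to that message.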
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }

    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        self.finalize_checkpoint(pending_checkpoint, cx);
    }

    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                this.update(cx, |this, cx| {
                    this.pending_checkpoint = if equal {
                        Some(pending_checkpoint)
                    } else {
                        this.insert_checkpoint(pending_checkpoint, cx);
                        Some(ThreadCheckpoint {
                            message_id: this.next_message_id,
                            git_checkpoint: final_checkpoint,
                        })
                    }
                })?;

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }

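    /// Returns whether the message at `ix` ends a turn: either it is the last
    /// message and generation has finished, or it is an Assistant message
    /// followed by a visible (non-hidden) User message.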
    pub fn is_turn_end(&self, ix: usize) -> bool {
        if self.messages.is_empty() {
            return false;
        }

        if !self.is_generating() && ix == self.messages.len() - 1 {
            return true;
        }

        let Some(message) = self.messages.get(ix) else {
            return false;
        };

        if message.role != Role::Assistant {
            return false;
        }

        self.messages
            .get(ix + 1)
            .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
            .unwrap_or(false)
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits.
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, &self.project, cx)
    }

    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display the image.
                None
            }
        }
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Return tools that are both enabled and supported by the model
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            self.profile
                .enabled_tools(cx)
                .into_iter()
                .filter_map(|(name, tool)| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name: name.into(),
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            Vec::default()
        }
    }

    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        message_id
    }

    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }

    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
            ui_only: false,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }

    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }

    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
    pub fn text(&self) -> String {
        let mut text = String::new();

        for message in &self.messages {
            text.push_str(match message.role {
                language_model::Role::User => "User:",
                language_model::Role::Assistant => "Agent:",
                language_model::Role::System => "System:",
            });
            text.push('\n');

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(content) => text.push_str(content),
                    MessageSegment::Thinking { text: content, .. } => {
                        text.push_str(&format!("<think>{}</think>", content))
                    }
                    MessageSegment::RedactedThinking(_) => {}
                }
            }
            text.push('\n');
        }

        text
    }

    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }

    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        self.flush_notifications(model.clone(), intent, cx);

        self.finalize_pending_checkpoint(cx);
        self.stream_completion(
            self.to_completion_request(model.clone(), intent, cx),
            model,
            intent,
            window,
            cx,
        );
    }

    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
            thinking_allowed: true,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

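        // Mark only the most recent cacheable message: prompt caching reuses
        // the prefix up to the last cache breakpoint, so a single marker on
        // the newest stable message covers everything before it. Messages
        // with still-pending tool results were excluded above so the cached
        // prefix does not shift between requests.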
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }

    fn to_summarize_request(
        &self,
        model: &Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        added_user_message: String,
        cx: &App,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(model, cx),
            thinking_allowed: false,
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => request_message
                        .content
                        .push(MessageContent::Text(text.clone())),
                    MessageSegment::Thinking { .. } => {}
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            if request_message.content.is_empty() {
                continue;
            }

            request.messages.push(request_message);
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(added_user_message)],
            cache: false,
        });

        request
    }

    /// Inserts auto-generated notifications (if any) into the thread.
    fn flush_notifications(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) {
        match intent {
            CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
                if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
                    cx.emit(ThreadEvent::ToolFinished {
                        tool_use_id: pending_tool_use.id.clone(),
                        pending_tool_use: Some(pending_tool_use),
                    });
                }
            }
            CompletionIntent::ThreadSummarization
            | CompletionIntent::ThreadContextSummarization
            | CompletionIntent::CreateFile
            | CompletionIntent::EditFile
            | CompletionIntent::InlineAssist
            | CompletionIntent::TerminalInlineAssist
            | CompletionIntent::GenerateGitCommitMessage => {}
        };
    }

    fn attach_tracked_files_state(
        &mut self,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Option<PendingToolUse> {
        // Represent the notification as a simulated `project_notifications` tool call
        let tool_name = Arc::from("project_notifications");
        let tool = self.tools.read(cx).tool(&tool_name, cx)?;

        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return None;
        }

        if self
            .action_log
            .update(cx, |log, cx| log.unnotified_user_edits(cx).is_none())
        {
            return None;
        }

        let input = serde_json::json!({});
        let request = Arc::new(LanguageModelRequest::default()); // unused
        let window = None;
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model.clone(),
            window,
            cx,
        );

        let tool_use_id =
            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));

        let tool_use = LanguageModelToolUse {
            id: tool_use_id.clone(),
            name: tool_name.clone(),
            raw_input: "{}".to_string(),
            input: serde_json::json!({}),
            is_input_complete: true,
        };

        let tool_output = cx.background_executor().block(tool_result.output);

        // Attach a project_notification tool call to the latest existing
        // Assistant message. We cannot create a new Assistant message
        // because thinking models require a `thinking` block that we
        // cannot mock. We cannot send a notification as a normal
        // (non-tool-use) User message because this distracts the Agent
        // too much.
        let tool_message_id = self
            .messages
            .iter()
            .enumerate()
            .rfind(|(_, message)| message.role == Role::Assistant)
            .map(|(_, message)| message.id)?;

        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: self.last_prompt_id.clone(),
        };

        self.tool_use
            .request_tool_use(tool_message_id, tool_use, tool_use_metadata, cx);

        self.tool_use.insert_tool_output(
            tool_use_id,
            tool_name,
            tool_output,
            self.configured_model.as_ref(),
            self.completion_mode,
        )
    }

    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let completion_mode = request
            .mode
            .unwrap_or(cloud_llm_client::CompletionMode::Normal);

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        match event? {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // would result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
1731 LanguageModelCompletionEvent::RedactedThinking { data } => {
1732 thread.received_chunk();
1733
1734 if let Some(last_message) = thread.messages.last_mut() {
1735 if last_message.role == Role::Assistant
1736 && !thread.tool_use.has_tool_results(last_message.id)
1737 {
1738 last_message.push_redacted_thinking(data);
1739 } else {
1740 request_assistant_message_id =
1741 Some(thread.insert_assistant_message(
1742 vec![MessageSegment::RedactedThinking(data)],
1743 cx,
1744 ));
1745 };
1746 }
1747 }
1748 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1749 let last_assistant_message_id = request_assistant_message_id
1750 .unwrap_or_else(|| {
1751 let new_assistant_message_id =
1752 thread.insert_assistant_message(vec![], cx);
1753 request_assistant_message_id =
1754 Some(new_assistant_message_id);
1755 new_assistant_message_id
1756 });
1757
1758 let tool_use_id = tool_use.id.clone();
1759 let streamed_input = if tool_use.is_input_complete {
1760 None
1761 } else {
1762 Some(tool_use.input.clone())
1763 };
1764
1765 let ui_text = thread.tool_use.request_tool_use(
1766 last_assistant_message_id,
1767 tool_use,
1768 tool_use_metadata.clone(),
1769 cx,
1770 );
1771
1772 if let Some(input) = streamed_input {
1773 cx.emit(ThreadEvent::StreamedToolUse {
1774 tool_use_id,
1775 ui_text,
1776 input,
1777 });
1778 }
1779 }
1780 LanguageModelCompletionEvent::ToolUseJsonParseError {
1781 id,
1782 tool_name,
1783 raw_input: invalid_input_json,
1784 json_parse_error,
1785 } => {
1786 thread.receive_invalid_tool_json(
1787 id,
1788 tool_name,
1789 invalid_input_json,
1790 json_parse_error,
1791 window,
1792 cx,
1793 );
1794 }
1795 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1796 if let Some(completion) = thread
1797 .pending_completions
1798 .iter_mut()
1799 .find(|completion| completion.id == pending_completion_id)
1800 {
1801 match status_update {
1802 CompletionRequestStatus::Queued { position } => {
1803 completion.queue_state =
1804 QueueState::Queued { position };
1805 }
1806 CompletionRequestStatus::Started => {
1807 completion.queue_state = QueueState::Started;
1808 }
1809 CompletionRequestStatus::Failed {
1810 code,
1811 message,
1812 request_id: _,
1813 retry_after,
1814 } => {
1815 return Err(
1816 LanguageModelCompletionError::from_cloud_failure(
1817 model.upstream_provider_name(),
1818 code,
1819 message,
1820 retry_after.map(Duration::from_secs_f64),
1821 ),
1822 );
1823 }
1824 CompletionRequestStatus::UsageUpdated { amount, limit } => {
1825 thread.update_model_request_usage(
1826 amount as u32,
1827 limit,
1828 cx,
1829 );
1830 }
1831 CompletionRequestStatus::ToolUseLimitReached => {
1832 thread.tool_use_limit_reached = true;
1833 cx.emit(ThreadEvent::ToolUseLimitReached);
1834 }
1835 }
1836 }
1837 }
1838 }
1839
1840 thread.touch_updated_at();
1841 cx.emit(ThreadEvent::StreamedCompletion);
1842 cx.notify();
1843
1844 Ok(())
1845 })??;
1846
1847 smol::future::yield_now().await;
1848 }
1849
1850 thread.update(cx, |thread, cx| {
1851 thread.last_received_chunk_at = None;
1852 thread
1853 .pending_completions
1854 .retain(|completion| completion.id != pending_completion_id);
1855
1856 // If there is a response without tool use, summarize the message. Otherwise,
1857 // allow two tool uses before summarizing.
1858 if matches!(thread.summary, ThreadSummary::Pending)
1859 && thread.messages.len() >= 2
1860 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1861 {
1862 thread.summarize(cx);
1863 }
1864 })?;
1865
1866 anyhow::Ok(stop_reason)
1867 };
1868
1869 let result = stream_completion.await;
            let mut retry_scheduled = false;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => {
                            match stop_reason {
                                StopReason::ToolUse => {
                                    let tool_uses =
                                        thread.use_pending_tools(window, model.clone(), cx);
                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                                }
                                StopReason::EndTurn | StopReason::MaxTokens => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });
                                }
                                StopReason::Refusal => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });

                                    // Remove the turn that was refused.
                                    //
                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                    {
                                        let mut messages_to_remove = Vec::new();

                                        for (ix, message) in
                                            thread.messages.iter().enumerate().rev()
                                        {
                                            messages_to_remove.push(message.id);

                                            if message.role == Role::User {
                                                if ix == 0 {
                                                    break;
                                                }

                                                if let Some(prev_message) =
                                                    thread.messages.get(ix - 1)
                                                    && prev_message.role == Role::Assistant
                                                {
                                                    break;
                                                }
                                            }
                                        }

                                        for message_id in messages_to_remove {
                                            thread.delete_message(message_id, cx);
                                        }
                                    }

                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                        header: "Language model refusal".into(),
                                        message:
                                            "Model refused to generate content for safety reasons."
                                                .into(),
                                    }));
                                }
                            }

                            // We successfully completed, so cancel any remaining retries.
                            thread.retry_state = None;
                        }
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(completion_error) =
                                error.downcast_ref::<LanguageModelCompletionError>()
                            {
                                match &completion_error {
                                    LanguageModelCompletionError::PromptTooLarge {
                                        tokens, ..
                                    } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread
                                                .total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(
                                                    model
                                                        .max_token_count_for_mode(completion_mode)
                                                        .saturating_add(1),
                                                )
                                        });
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: tokens,
                                        });
                                        cx.notify();
                                    }
                                    _ => {
                                        if let Some(retry_strategy) =
                                            Thread::get_retry_strategy(completion_error)
                                        {
                                            log::info!(
                                                "Retrying with {:?} for language model completion error {:?}",
                                                retry_strategy,
                                                completion_error
                                            );

                                            retry_scheduled = thread
                                                .handle_retryable_error_with_delay(
                                                    completion_error,
                                                    Some(retry_strategy),
                                                    model.clone(),
                                                    intent,
                                                    window,
                                                    cx,
                                                );
                                        }
                                    }
                                }
                            }

                            if !retry_scheduled {
                                thread.cancel_last_completion(window, cx);
                            }
                        }
                    }

                    if !retry_scheduled {
                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
                    }

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }

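    /// Generates a short, single-line summary of the thread using the configured
    /// thread-summary model. Only the first line of the streamed response is
    /// kept, and `ThreadEvent::SummaryGenerated` is emitted once the attempt
    /// resolves (successfully or not).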
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            log::warn!("No thread summary model configured");
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let request = self.to_summarize_request(
            &model.model,
            CompletionIntent::ThreadSummarization,
            SUMMARIZE_THREAD_PROMPT.into(),
            cx,
        );

        self.summary = ThreadSummary::Generating;

        self.pending_summary = cx.spawn(async move |this, cx| {
            let result = async {
                let mut messages = model.model.stream_completion(request, cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let Ok(event) = event else {
                        continue;
                    };
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            this.update(cx, |thread, cx| {
                                thread.update_model_request_usage(amount as u32, limit, cx);
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                anyhow::Ok(new_summary)
            }
            .await;

            this.update(cx, |this, cx| {
                match result {
                    Ok(new_summary) => {
                        if new_summary.is_empty() {
                            this.summary = ThreadSummary::Error;
                        } else {
                            this.summary = ThreadSummary::Ready(new_summary.into());
                        }
                    }
                    Err(err) => {
                        this.summary = ThreadSummary::Error;
                        log::error!("Failed to generate thread summary: {}", err);
                    }
                }
                cx.emit(ThreadEvent::SummaryGenerated);
            })
            .log_err()?;

            Some(())
        });
    }

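    /// Maps a completion error to the retry strategy to use for it, or `None` when
    /// retrying cannot help (e.g. an invalid API key or an oversized prompt).
    /// Time-based failures honor the provider's `retry_after` hint when present,
    /// falling back to `BASE_RETRY_DELAY` otherwise. An illustrative sketch:
    ///
    /// ```ignore
    /// // e.g. an HTTP 429 maps to exponential backoff capped at MAX_RETRY_ATTEMPTS:
    /// if let Some(RetryStrategy::ExponentialBackoff { initial_delay, max_attempts }) =
    ///     Thread::get_retry_strategy(&error)
    /// {
    ///     assert_eq!(initial_delay, BASE_RETRY_DELAY);
    ///     assert_eq!(max_attempts, MAX_RETRY_ATTEMPTS);
    /// }
    /// ```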
    fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
        use LanguageModelCompletionError::*;

        // General strategy here:
        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
        match error {
            HttpResponseError {
                status_code: StatusCode::TOO_MANY_REQUESTS,
                ..
            } => Some(RetryStrategy::ExponentialBackoff {
                initial_delay: BASE_RETRY_DELAY,
                max_attempts: MAX_RETRY_ATTEMPTS,
            }),
            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
                Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    max_attempts: MAX_RETRY_ATTEMPTS,
                })
            }
            UpstreamProviderError {
                status,
                retry_after,
                ..
            } => match *status {
                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
                    Some(RetryStrategy::Fixed {
                        delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                        max_attempts: MAX_RETRY_ATTEMPTS,
                    })
                }
                StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    // An Internal Server Error could be anything; retry up to 3 times.
                    max_attempts: 3,
                }),
                status => {
                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
                    // but we encounter it frequently in practice. See https://http.dev/529
                    if status.as_u16() == 529 {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: MAX_RETRY_ATTEMPTS,
                        })
                    } else {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: 2,
                        })
                    }
                }
            },
            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            ApiReadResponseError { .. }
            | HttpSend { .. }
            | DeserializeResponse { .. }
            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Retrying these errors definitely shouldn't help.
            HttpResponseError {
                status_code:
                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
                ..
            }
            | AuthenticationError { .. }
            | PermissionError { .. }
            | NoApiKey { .. }
            | ApiEndpointNotFound { .. }
            | PromptTooLarge { .. } => None,
            // These errors might be transient, so retry them once.
            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 1,
            }),
            // Retry all other 4xx and 5xx errors up to 3 times.
            HttpResponseError { status_code, .. }
                if status_code.is_client_error() || status_code.is_server_error() =>
            {
                Some(RetryStrategy::Fixed {
                    delay: BASE_RETRY_DELAY,
                    max_attempts: 3,
                })
            }
            Other(err)
                if err.is::<PaymentRequiredError>()
                    || err.is::<ModelRequestLimitReachedError>() =>
            {
                // Retrying won't help for Payment Required or Model Request Limit errors (where
                // the user must upgrade to usage-based billing to get more requests, or else wait
                // a significant amount of time for the request limit to reset).
                None
            }
            // Conservatively assume that any other errors might be transient, but only retry twice.
            HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 2,
            }),
        }
    }

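    /// Handles a retryable completion error. Automatic retries only happen in
    /// Burn Mode; otherwise a `RetryableError` is shown so the user can retry
    /// manually (or enable Burn Mode). Returns `true` if a retry was scheduled.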
    fn handle_retryable_error_with_delay(
        &mut self,
        error: &LanguageModelCompletionError,
        strategy: Option<RetryStrategy>,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Store context for the Retry button
        self.last_error_context = Some((model.clone(), intent));

        // Only auto-retry if Burn Mode is enabled
        if self.completion_mode != CompletionMode::Burn {
            // Show error with retry options
            cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
                message: format!(
                    "{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.",
                    error
                )
                .into(),
                can_enable_burn_mode: true,
            }));
            return false;
        }

        let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
            return false;
        };

        let max_attempts = match &strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
        };

        let retry_state = self.retry_state.get_or_insert(RetryState {
            attempt: 0,
            max_attempts,
            intent,
        });

        retry_state.attempt += 1;
        let attempt = retry_state.attempt;
        let max_attempts = retry_state.max_attempts;
        let intent = retry_state.intent;

        if attempt <= max_attempts {
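            // Example schedule with BASE_RETRY_DELAY = 5s as the initial delay:
            // ExponentialBackoff waits 5s, 10s, 20s, 40s on attempts 1-4, while
            // Fixed waits the same `delay` before every attempt.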
            let delay = match &strategy {
                RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
                    let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
                    Duration::from_secs(delay_secs)
                }
                RetryStrategy::Fixed { delay, .. } => *delay,
            };

            // Add a transient message to inform the user
            let delay_secs = delay.as_secs();
            let retry_message = if max_attempts == 1 {
                format!("{error}. Retrying in {delay_secs} seconds...")
            } else {
                format!(
                    "{error}. Retrying (attempt {attempt} of {max_attempts}) \
                     in {delay_secs} seconds..."
                )
            };
            log::warn!(
                "Retrying completion request (attempt {attempt} of {max_attempts}) \
                 in {delay_secs} seconds: {error:?}",
            );

            // Add a UI-only message instead of a regular message
            let id = self.next_message_id.post_inc();
            self.messages.push(Message {
                id,
                role: Role::System,
                segments: vec![MessageSegment::Text(retry_message)],
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
                is_hidden: false,
                ui_only: true,
            });
            cx.emit(ThreadEvent::MessageAdded(id));

            // Schedule the retry
            let thread_handle = cx.entity().downgrade();

            cx.spawn(async move |_thread, cx| {
                cx.background_executor().timer(delay).await;

                thread_handle
                    .update(cx, |thread, cx| {
                        // Retry the completion
                        thread.send_to_model(model, intent, window, cx);
                    })
                    .log_err();
            })
            .detach();

            true
        } else {
            // Max retries exceeded
            self.retry_state = None;

            // Stop generating since we're giving up on retrying.
            self.pending_completions.clear();

            // Show error alongside a Retry button, but no
            // Enable Burn Mode button (since it's already enabled)
            cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
                message: format!("Failed after retrying: {}", error).into(),
                can_enable_burn_mode: false,
            }));

            false
        }
    }

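    /// Starts generating a detailed summary of the thread, unless one is already
    /// generating (or generated) for the current last message.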
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            SUMMARIZE_THREAD_DETAILED_PROMPT.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, cx);
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save the thread so its summary can be reused later
            if let Some(thread) = thread.upgrade()
                && let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                })
            {
                save_task.await.log_err();
            }

            Some(())
        });
    }

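    /// Waits for any in-flight detailed summary to settle, returning the generated
    /// text, or the thread's full text if no detailed summary was generated.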
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }

    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.detailed_summary_rx
            .borrow()
            .text()
            .unwrap_or_else(|| self.text().into())
    }

    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }

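    /// Starts every idle pending tool use requested by the model, returning the
    /// tool uses that were dispatched. Tools that require confirmation are held
    /// back and surfaced via `ThreadEvent::ToolConfirmationNeeded` instead of
    /// running immediately.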
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> Vec<PendingToolUse> {
        let request =
            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
        }

        pending_tool_uses
    }

    fn use_pending_tool(
        &mut self,
        tool_use: PendingToolUse,
        request: Arc<LanguageModelRequest>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
        };

        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
        }

        if tool.needs_confirmation(&tool_use.input, &self.project, cx)
            && !AgentSettings::get_global(cx).always_allow_tool_actions
        {
            self.tool_use.confirm_tool_use(
                tool_use.id,
                tool_use.ui_text,
                tool_use.input,
                request,
                tool,
            );
            cx.emit(ThreadEvent::ToolConfirmationNeeded);
        } else {
            self.run_tool(
                tool_use.id,
                tool_use.ui_text,
                tool_use.input,
                request,
                tool,
                model,
                window,
                cx,
            );
        }
    }

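    /// Responds to a tool call whose name doesn't match any enabled tool by
    /// sending the model an error listing the tools that are actually available,
    /// so it can correct itself on the next turn.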
    pub fn handle_hallucinated_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        hallucinated_tool_name: Arc<str>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let available_tools = self.profile.enabled_tools(cx);

        let tool_list = available_tools
            .iter()
            .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
            .collect::<Vec<_>>()
            .join("\n");

        let error_message = format!(
            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
            hallucinated_tool_name, tool_list
        );

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            hallucinated_tool_name,
            Err(anyhow!("Missing tool call: {error_message}")),
            self.configured_model.as_ref(),
            self.completion_mode,
        );

        cx.emit(ThreadEvent::MissingToolUse {
            tool_use_id: tool_use_id.clone(),
            ui_text: error_message.into(),
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }

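    /// Records a tool call whose input failed to parse as JSON, surfacing the
    /// parse error both to the model (as an error tool result) and to the UI.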
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
            self.completion_mode,
        );
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }

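    /// Runs a (confirmed or auto-approved) tool call, tracking it as a pending
    /// tool use until the spawned task resolves with the tool's output.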
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }

    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                            thread.completion_mode,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }

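    /// Called when a tool use completes or is canceled. Once all tools for the
    /// turn have finished, their results are automatically sent back to the
    /// configured model, unless the turn was canceled.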
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished()
            && let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref()
            && !canceled
        {
            self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }

    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();

        self.retry_state = None;

        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }

    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }

    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }

    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|(name, _)| name.clone().into())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_message_content())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }

    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, _| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                timestamp: Utc::now(),
            })
        })
    }

    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().into_owned();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }

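    /// Renders the entire thread as Markdown. The output shape is roughly
    /// (abridged, illustrative):
    ///
    /// ```markdown
    /// # <thread summary>
    ///
    /// ## User
    ///
    /// <message text>
    ///
    /// ## Agent
    ///
    /// **Use Tool: <tool name> (<tool use id>)**
    /// ```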
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        let summary = self.summary().or_default();
        writeln!(markdown, "# {summary}\n")?;

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Agent",
                    Role::System => "System",
                }
            )?;

            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                match &tool_result.content {
                    LanguageModelToolResultContent::Text(text) => {
                        writeln!(markdown, "{text}")?;
                    }
                    LanguageModelToolResultContent::Image(image) => {
                        // Emit the image as a Markdown image reference.
                        writeln!(markdown, "![Image]({})", image.source)?;
                    }
                }

                if let Some(output) = tool_result.output.as_ref() {
                    writeln!(
                        markdown,
                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
                        serde_json::to_string_pretty(output)?
                    )?;
                }
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }

    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }

    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }

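    /// Returns the token usage recorded just before `message_id`, together with
    /// the model's maximum context size for the current completion mode.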
    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
        let Some(model) = self.configured_model.as_ref() else {
            return TotalTokenUsage::default();
        };

        let max = model
            .model
            .max_token_count_for_mode(self.completion_mode().into());

        let index = self
            .messages
            .iter()
            .position(|msg| msg.id == message_id)
            .unwrap_or(0);

        if index == 0 {
            return TotalTokenUsage { total: 0, max };
        }

        let token_usage = self
            .request_token_usage
            .get(index - 1)
            .cloned()
            .unwrap_or_default();

        TotalTokenUsage {
            total: token_usage.total_tokens(),
            max,
        }
    }

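    /// Returns the thread's cumulative token usage against the model's context
    /// window. If this model previously exceeded its context window, the exact
    /// token count from that error is preferred over our running estimate.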
    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
        let model = self.configured_model.as_ref()?;

        let max = model
            .model
            .max_token_count_for_mode(self.completion_mode().into());

        if let Some(exceeded_error) = &self.exceeded_window_error
            && model.model.id() == exceeded_error.model_id
        {
            return Some(TotalTokenUsage {
                total: exceeded_error.token_count,
                max,
            });
        }

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens();

        Some(TotalTokenUsage { total, max })
    }

    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }

    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }

    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project
            .read(cx)
            .user_store()
            .update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            });
    }

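    /// Rejects a tool use that was awaiting confirmation, reporting the denial
    /// back to the model as a tool error so it can proceed without the result.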
    pub fn deny_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let err = Err(anyhow::anyhow!(
            "Permission to run tool action denied by user"
        ));

        self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            err,
            self.configured_model.as_ref(),
            self.completion_mode,
        );
        self.tool_finished(tool_use_id, None, true, window, cx);
    }
}

#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
    #[error("Retryable error: {message}")]
    RetryableError {
        message: SharedString,
        can_enable_burn_mode: bool,
    },
}

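/// Events emitted by a [`Thread`] as it streams completions, runs tools, and
/// updates its summaries; listeners such as `ActiveThread` subscribe to these.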
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
}

impl EventEmitter<ThreadEvent> for Thread {}

struct PendingCompletion {
    id: usize,
    queue_state: QueueState,
    _task: Task<()>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        context::load_context, context_store::ContextStore, thread_store,
        thread_store::ThreadStore,
    };
    use agent_settings::{AgentProfileId, AgentSettings};
    use assistant_tool::ToolRegistry;
    use assistant_tools;
    use fs::Fs;
    use futures::StreamExt;
    use futures::future::BoxFuture;
    use futures::stream::BoxStream;
    use gpui::TestAppContext;
    use http_client;
    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
    use language_model::{
        LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
        LanguageModelProviderName, LanguageModelToolChoice,
    };
    use parking_lot::Mutex;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{LanguageModelParameters, Settings, SettingsStore};
    use std::sync::Arc;
    use std::time::Duration;
    use util::path;
    use workspace::Workspace;

    // Test-specific constants
    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;

    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message =
            thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }

    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message =
            thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    #[gpui::test]
    #[ignore] // turn this test on when project_notifications tool is re-enabled
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context. This will be a tracked buffer
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message and assistant response
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });
        cx.run_until_parked();

        // We shouldn't have a stale buffer notification yet
        let notifications = thread.read_with(cx, |thread, _| {
            find_tool_uses(thread, "project_notifications")
        });
        assert!(
            notifications.is_empty(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });
        cx.run_until_parked();

        // Check for the stale buffer warning
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        cx.run_until_parked();

        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        let [notification] = notifications.as_slice() else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        assert!(notification_content.contains("These files have changed since the last read:"));
        assert!(notification_content.contains("code.rs"));

        // Insert another user message and flush notifications again
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Can you tell me more?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        cx.run_until_parked();

        // There should be no new notifications (we already flushed one)
        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        assert_eq!(
            notifications.len(),
            1,
            "Should still have only one notification after second flush - no duplicates"
        );
    }

    fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
        thread
            .messages()
            .flat_map(|message| {
                thread
                    .tool_results_for_message(message.id)
                    .into_iter()
                    .filter(|result| result.tool_name == tool_name.into())
                    .cloned()
                    .collect::<Vec<_>>()
            })
            .collect()
    }

    #[gpui::test]
    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Check that we are starting with the default profile
        let profile = cx.read(|cx| thread.read(cx).profile.clone());
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
        assert_eq!(
            profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }

    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }

    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }

    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        cx.run_until_parked();
        fake_model.send_last_completion_stream_text_chunk("Brief");
        fake_model.send_last_completion_stream_text_chunk(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3916
3917 #[gpui::test]
3918 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3919 let fs = init_test_settings(cx);
3920
3921 let project = create_test_project(&fs, cx, json!({})).await;
3922
3923 let (_, _thread_store, thread, _context_store, model) =
3924 setup_test_environment(cx, project.clone()).await;
3925
3926 test_summarize_error(&model, &thread, cx);
3927
3928 // Now we should be able to set a summary
3929 thread.update(cx, |thread, cx| {
3930 thread.set_summary("Brief Intro", cx);
3931 });
3932
3933 thread.read_with(cx, |thread, _| {
3934 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3935 assert_eq!(thread.summary().or_default(), "Brief Intro");
3936 });
3937 }
3938
3939 #[gpui::test]
3940 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3941 let fs = init_test_settings(cx);
3942
3943 let project = create_test_project(&fs, cx, json!({})).await;
3944
3945 let (_, _thread_store, thread, _context_store, model) =
3946 setup_test_environment(cx, project.clone()).await;
3947
3948 test_summarize_error(&model, &thread, cx);
3949
3950 // Sending another message should not trigger another summarize request
3951 thread.update(cx, |thread, cx| {
3952 thread.insert_user_message(
3953 "How are you?",
3954 ContextLoadResult::default(),
3955 None,
3956 vec![],
3957 cx,
3958 );
3959 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3960 });
3961
3962 let fake_model = model.as_fake();
3963 simulate_successful_response(fake_model, cx);
3964
3965 thread.read_with(cx, |thread, _| {
3966 // State is still Error, not Generating
3967 assert!(matches!(thread.summary(), ThreadSummary::Error));
3968 });
3969
3970 // But the summarize request can be invoked manually
3971 thread.update(cx, |thread, cx| {
3972 thread.summarize(cx);
3973 });
3974
3975 thread.read_with(cx, |thread, _| {
3976 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3977 });
3978
3979 cx.run_until_parked();
3980 fake_model.send_last_completion_stream_text_chunk("A successful summary");
3981 fake_model.end_last_completion_stream();
3982 cx.run_until_parked();
3983
3984 thread.read_with(cx, |thread, _| {
3985 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3986 assert_eq!(thread.summary().or_default(), "A successful summary");
3987 });
3988 }
3989
3990     // Error kinds for `ErrorInjector`, a model wrapper whose completions always fail
3991 enum TestError {
3992 Overloaded,
3993 InternalServerError,
3994 }
3995
3996 struct ErrorInjector {
3997 inner: Arc<FakeLanguageModel>,
3998 error_type: TestError,
3999 }
4000
4001 impl ErrorInjector {
4002 fn new(error_type: TestError) -> Self {
4003 Self {
4004 inner: Arc::new(FakeLanguageModel::default()),
4005 error_type,
4006 }
4007 }
4008 }
4009
4010 impl LanguageModel for ErrorInjector {
4011 fn id(&self) -> LanguageModelId {
4012 self.inner.id()
4013 }
4014
4015 fn name(&self) -> LanguageModelName {
4016 self.inner.name()
4017 }
4018
4019 fn provider_id(&self) -> LanguageModelProviderId {
4020 self.inner.provider_id()
4021 }
4022
4023 fn provider_name(&self) -> LanguageModelProviderName {
4024 self.inner.provider_name()
4025 }
4026
4027 fn supports_tools(&self) -> bool {
4028 self.inner.supports_tools()
4029 }
4030
4031 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4032 self.inner.supports_tool_choice(choice)
4033 }
4034
4035 fn supports_images(&self) -> bool {
4036 self.inner.supports_images()
4037 }
4038
4039 fn telemetry_id(&self) -> String {
4040 self.inner.telemetry_id()
4041 }
4042
4043 fn max_token_count(&self) -> u64 {
4044 self.inner.max_token_count()
4045 }
4046
4047 fn count_tokens(
4048 &self,
4049 request: LanguageModelRequest,
4050 cx: &App,
4051 ) -> BoxFuture<'static, Result<u64>> {
4052 self.inner.count_tokens(request, cx)
4053 }
4054
4055 fn stream_completion(
4056 &self,
4057 _request: LanguageModelRequest,
4058 _cx: &AsyncApp,
4059 ) -> BoxFuture<
4060 'static,
4061 Result<
4062 BoxStream<
4063 'static,
4064 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4065 >,
4066 LanguageModelCompletionError,
4067 >,
4068 > {
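            // Build the configured failure and return it as the entire response stream.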
4069 let error = match self.error_type {
4070 TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
4071 provider: self.provider_name(),
4072 retry_after: None,
4073 },
4074 TestError::InternalServerError => {
4075 LanguageModelCompletionError::ApiInternalServerError {
4076 provider: self.provider_name(),
4077 message: "I'm a teapot orbiting the sun".to_string(),
4078 }
4079 }
4080 };
4081 async move {
4082 let stream = futures::stream::once(async move { Err(error) });
4083 Ok(stream.boxed())
4084 }
4085 .boxed()
4086 }
4087
4088 fn as_fake(&self) -> &FakeLanguageModel {
4089 &self.inner
4090 }
4091 }
4092
4093 #[gpui::test]
4094 async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4095 let fs = init_test_settings(cx);
4096
4097 let project = create_test_project(&fs, cx, json!({})).await;
4098 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4099
4100 // Enable Burn Mode to allow retries
4101 thread.update(cx, |thread, _| {
4102 thread.set_completion_mode(CompletionMode::Burn);
4103 });
4104
4105 // Create model that returns overloaded error
4106 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4107
4108 // Insert a user message
4109 thread.update(cx, |thread, cx| {
4110 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4111 });
4112
4113 // Start completion
4114 thread.update(cx, |thread, cx| {
4115 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4116 });
4117
4118 cx.run_until_parked();
4119
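        // A single failure should schedule the first retry and record retry state.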
4120 thread.read_with(cx, |thread, _| {
4121 assert!(thread.retry_state.is_some(), "Should have retry state");
4122 let retry_state = thread.retry_state.as_ref().unwrap();
4123 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4124 assert_eq!(
4125 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4126 "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
4127 );
4128 });
4129
4130 // Check that a retry message was added
4131 thread.read_with(cx, |thread, _| {
4132 let mut messages = thread.messages();
4133 assert!(
4134 messages.any(|msg| {
4135 msg.role == Role::System
4136 && msg.ui_only
4137 && msg.segments.iter().any(|seg| {
4138 if let MessageSegment::Text(text) = seg {
4139 text.contains("overloaded")
4140 && text
4141 .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4142 } else {
4143 false
4144 }
4145 })
4146 }),
4147 "Should have added a system retry message"
4148 );
4149 });
4150
4151 let retry_count = thread.update(cx, |thread, _| {
4152 thread
4153 .messages
4154 .iter()
4155 .filter(|m| {
4156 m.ui_only
4157 && m.segments.iter().any(|s| {
4158 if let MessageSegment::Text(text) = s {
4159 text.contains("Retrying") && text.contains("seconds")
4160 } else {
4161 false
4162 }
4163 })
4164 })
4165 .count()
4166 });
4167
4168 assert_eq!(retry_count, 1, "Should have one retry message");
4169 }
4170
4171 #[gpui::test]
4172 async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4173 let fs = init_test_settings(cx);
4174
4175 let project = create_test_project(&fs, cx, json!({})).await;
4176 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4177
4178 // Enable Burn Mode to allow retries
4179 thread.update(cx, |thread, _| {
4180 thread.set_completion_mode(CompletionMode::Burn);
4181 });
4182
4183 // Create model that returns internal server error
4184 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4185
4186 // Insert a user message
4187 thread.update(cx, |thread, cx| {
4188 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4189 });
4190
4191 // Start completion
4192 thread.update(cx, |thread, cx| {
4193 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4194 });
4195
4196 cx.run_until_parked();
4197
4198 // Check retry state on thread
4199 thread.read_with(cx, |thread, _| {
4200 assert!(thread.retry_state.is_some(), "Should have retry state");
4201 let retry_state = thread.retry_state.as_ref().unwrap();
4202 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4203 assert_eq!(
4204 retry_state.max_attempts, 3,
4205                 "Internal server errors should retry up to 3 times"
4206 );
4207 });
4208
4209 // Check that a retry message was added with provider name
4210 thread.read_with(cx, |thread, _| {
4211 let mut messages = thread.messages();
4212 assert!(
4213 messages.any(|msg| {
4214 msg.role == Role::System
4215 && msg.ui_only
4216 && msg.segments.iter().any(|seg| {
4217 if let MessageSegment::Text(text) = seg {
4218 text.contains("internal")
4219 && text.contains("Fake")
4220 && text.contains("Retrying")
4221 && text.contains("attempt 1 of 3")
4222 && text.contains("seconds")
4223 } else {
4224 false
4225 }
4226 })
4227 }),
4228 "Should have added a system retry message with provider name"
4229 );
4230 });
4231
4232 // Count retry messages
4233 let retry_count = thread.update(cx, |thread, _| {
4234 thread
4235 .messages
4236 .iter()
4237 .filter(|m| {
4238 m.ui_only
4239 && m.segments.iter().any(|s| {
4240 if let MessageSegment::Text(text) = s {
4241 text.contains("Retrying") && text.contains("seconds")
4242 } else {
4243 false
4244 }
4245 })
4246 })
4247 .count()
4248 });
4249
4250 assert_eq!(retry_count, 1, "Should have one retry message");
4251 }
4252
4253 #[gpui::test]
4254 async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4255 let fs = init_test_settings(cx);
4256
4257 let project = create_test_project(&fs, cx, json!({})).await;
4258 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4259
4260 // Enable Burn Mode to allow retries
4261 thread.update(cx, |thread, _| {
4262 thread.set_completion_mode(CompletionMode::Burn);
4263 });
4264
4265 // Create model that returns internal server error
4266 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4267
4268 // Insert a user message
4269 thread.update(cx, |thread, cx| {
4270 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4271 });
4272
4273         // Count completion requests: the initial attempt plus each retry
4275 let completion_count = Arc::new(Mutex::new(0));
4276 let completion_count_clone = completion_count.clone();
4277
4278 let _subscription = thread.update(cx, |_, cx| {
4279 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4280 if let ThreadEvent::NewRequest = event {
4281 *completion_count_clone.lock() += 1;
4282 }
4283 })
4284 });
4285
4286 // First attempt
4287 thread.update(cx, |thread, cx| {
4288 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4289 });
4290 cx.run_until_parked();
4291
4292 // Should have scheduled first retry - count retry messages
4293 let retry_count = thread.update(cx, |thread, _| {
4294 thread
4295 .messages
4296 .iter()
4297 .filter(|m| {
4298 m.ui_only
4299 && m.segments.iter().any(|s| {
4300 if let MessageSegment::Text(text) = s {
4301 text.contains("Retrying") && text.contains("seconds")
4302 } else {
4303 false
4304 }
4305 })
4306 })
4307 .count()
4308 });
4309 assert_eq!(retry_count, 1, "Should have scheduled first retry");
4310
4311 // Check retry state
4312 thread.read_with(cx, |thread, _| {
4313 assert!(thread.retry_state.is_some(), "Should have retry state");
4314 let retry_state = thread.retry_state.as_ref().unwrap();
4315 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4316 assert_eq!(
4317 retry_state.max_attempts, 3,
4318 "Internal server errors should retry up to 3 times"
4319 );
4320 });
4321
4322 // Advance clock for first retry
4323 cx.executor().advance_clock(BASE_RETRY_DELAY);
4324 cx.run_until_parked();
4325
4326 // Advance clock for second retry
4327 cx.executor().advance_clock(BASE_RETRY_DELAY);
4328 cx.run_until_parked();
4329
4330 // Advance clock for third retry
4331 cx.executor().advance_clock(BASE_RETRY_DELAY);
4332 cx.run_until_parked();
4333
4334 // Should have completed all retries - count retry messages
4335 let retry_count = thread.update(cx, |thread, _| {
4336 thread
4337 .messages
4338 .iter()
4339 .filter(|m| {
4340 m.ui_only
4341 && m.segments.iter().any(|s| {
4342 if let MessageSegment::Text(text) = s {
4343 text.contains("Retrying") && text.contains("seconds")
4344 } else {
4345 false
4346 }
4347 })
4348 })
4349 .count()
4350 });
4351 assert_eq!(
4352 retry_count, 3,
4353 "Should have 3 retries for internal server errors"
4354 );
4355
4356 // For internal server errors, we retry 3 times and then give up
4357 // Check that retry_state is cleared after all retries
4358 thread.read_with(cx, |thread, _| {
4359 assert!(
4360 thread.retry_state.is_none(),
4361 "Retry state should be cleared after all retries"
4362 );
4363 });
4364
4365 // Verify total attempts (1 initial + 3 retries)
4366 assert_eq!(
4367 *completion_count.lock(),
4368 4,
4369 "Should have attempted once plus 3 retries"
4370 );
4371 }
4372
4373 #[gpui::test]
4374 async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4375 let fs = init_test_settings(cx);
4376
4377 let project = create_test_project(&fs, cx, json!({})).await;
4378 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4379
4380 // Enable Burn Mode to allow retries
4381 thread.update(cx, |thread, _| {
4382 thread.set_completion_mode(CompletionMode::Burn);
4383 });
4384
4385 // Create model that returns overloaded error
4386 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4387
4388 // Insert a user message
4389 thread.update(cx, |thread, cx| {
4390 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4391 });
4392
4393 // Track events
4394 let stopped_with_error = Arc::new(Mutex::new(false));
4395 let stopped_with_error_clone = stopped_with_error.clone();
4396
4397 let _subscription = thread.update(cx, |_, cx| {
4398 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4399 if let ThreadEvent::Stopped(Err(_)) = event {
4400 *stopped_with_error_clone.lock() = true;
4401 }
4402 })
4403 });
4404
4405 // Start initial completion
4406 thread.update(cx, |thread, cx| {
4407 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4408 });
4409 cx.run_until_parked();
4410
4411 // Advance through all retries
4412 for _ in 0..MAX_RETRY_ATTEMPTS {
4413 cx.executor().advance_clock(BASE_RETRY_DELAY);
4414 cx.run_until_parked();
4415 }
4416
4417 let retry_count = thread.update(cx, |thread, _| {
4418 thread
4419 .messages
4420 .iter()
4421 .filter(|m| {
4422 m.ui_only
4423 && m.segments.iter().any(|s| {
4424 if let MessageSegment::Text(text) = s {
4425 text.contains("Retrying") && text.contains("seconds")
4426 } else {
4427 false
4428 }
4429 })
4430 })
4431 .count()
4432 });
4433
4434 // After max retries, should emit Stopped(Err(...)) event
4435 assert_eq!(
4436 retry_count, MAX_RETRY_ATTEMPTS as usize,
4437 "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
4438 );
4439 assert!(
4440 *stopped_with_error.lock(),
4441 "Should emit Stopped(Err(...)) event after max retries exceeded"
4442 );
4443
4444 // Retry state should be cleared
4445 thread.read_with(cx, |thread, _| {
4446 assert!(
4447 thread.retry_state.is_none(),
4448 "Retry state should be cleared after max retries"
4449 );
4450
4451 // Verify we have the expected number of retry messages
4452 let retry_messages = thread
4453 .messages
4454 .iter()
4455 .filter(|msg| msg.ui_only && msg.role == Role::System)
4456 .count();
4457 assert_eq!(
4458 retry_messages, MAX_RETRY_ATTEMPTS as usize,
4459 "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
4460 );
4461 });
4462 }
4463
4464 #[gpui::test]
4465 async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
4466 let fs = init_test_settings(cx);
4467
4468 let project = create_test_project(&fs, cx, json!({})).await;
4469 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4470
4471 // Enable Burn Mode to allow retries
4472 thread.update(cx, |thread, _| {
4473 thread.set_completion_mode(CompletionMode::Burn);
4474 });
4475
4476 // We'll use a wrapper to switch behavior after first failure
4477 struct RetryTestModel {
4478 inner: Arc<FakeLanguageModel>,
4479 failed_once: Arc<Mutex<bool>>,
4480 }
4481
4482 impl LanguageModel for RetryTestModel {
4483 fn id(&self) -> LanguageModelId {
4484 self.inner.id()
4485 }
4486
4487 fn name(&self) -> LanguageModelName {
4488 self.inner.name()
4489 }
4490
4491 fn provider_id(&self) -> LanguageModelProviderId {
4492 self.inner.provider_id()
4493 }
4494
4495 fn provider_name(&self) -> LanguageModelProviderName {
4496 self.inner.provider_name()
4497 }
4498
4499 fn supports_tools(&self) -> bool {
4500 self.inner.supports_tools()
4501 }
4502
4503 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4504 self.inner.supports_tool_choice(choice)
4505 }
4506
4507 fn supports_images(&self) -> bool {
4508 self.inner.supports_images()
4509 }
4510
4511 fn telemetry_id(&self) -> String {
4512 self.inner.telemetry_id()
4513 }
4514
4515 fn max_token_count(&self) -> u64 {
4516 self.inner.max_token_count()
4517 }
4518
4519 fn count_tokens(
4520 &self,
4521 request: LanguageModelRequest,
4522 cx: &App,
4523 ) -> BoxFuture<'static, Result<u64>> {
4524 self.inner.count_tokens(request, cx)
4525 }
4526
4527 fn stream_completion(
4528 &self,
4529 request: LanguageModelRequest,
4530 cx: &AsyncApp,
4531 ) -> BoxFuture<
4532 'static,
4533 Result<
4534 BoxStream<
4535 'static,
4536 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4537 >,
4538 LanguageModelCompletionError,
4539 >,
4540 > {
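                // Fail the first request with ServerOverloaded; later requests delegate to the inner fake model.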
4541 if !*self.failed_once.lock() {
4542 *self.failed_once.lock() = true;
4543 let provider = self.provider_name();
4544 // Return error on first attempt
4545 let stream = futures::stream::once(async move {
4546 Err(LanguageModelCompletionError::ServerOverloaded {
4547 provider,
4548 retry_after: None,
4549 })
4550 });
4551 async move { Ok(stream.boxed()) }.boxed()
4552 } else {
4553 // Succeed on retry
4554 self.inner.stream_completion(request, cx)
4555 }
4556 }
4557
4558 fn as_fake(&self) -> &FakeLanguageModel {
4559 &self.inner
4560 }
4561 }
4562
4563 let model = Arc::new(RetryTestModel {
4564 inner: Arc::new(FakeLanguageModel::default()),
4565 failed_once: Arc::new(Mutex::new(false)),
4566 });
4567
4568 // Insert a user message
4569 thread.update(cx, |thread, cx| {
4570 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4571 });
4572
4573         // Track when the retried completion streams successfully
4575 let retry_completed = Arc::new(Mutex::new(false));
4576 let retry_completed_clone = retry_completed.clone();
4577
4578 let _subscription = thread.update(cx, |_, cx| {
4579 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4580 if let ThreadEvent::StreamedCompletion = event {
4581 *retry_completed_clone.lock() = true;
4582 }
4583 })
4584 });
4585
4586 // Start completion
4587 thread.update(cx, |thread, cx| {
4588 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4589 });
4590 cx.run_until_parked();
4591
4592 // Get the retry message ID
4593 let retry_message_id = thread.read_with(cx, |thread, _| {
4594 thread
4595 .messages()
4596 .find(|msg| msg.role == Role::System && msg.ui_only)
4597 .map(|msg| msg.id)
4598 .expect("Should have a retry message")
4599 });
4600
4601 // Wait for retry
4602 cx.executor().advance_clock(BASE_RETRY_DELAY);
4603 cx.run_until_parked();
4604
4605 // Stream some successful content
4606 let fake_model = model.as_fake();
4607 // After the retry, there should be a new pending completion
4608 let pending = fake_model.pending_completions();
4609 assert!(
4610 !pending.is_empty(),
4611 "Should have a pending completion after retry"
4612 );
4613 fake_model.send_completion_stream_text_chunk(&pending[0], "Success!");
4614 fake_model.end_completion_stream(&pending[0]);
4615 cx.run_until_parked();
4616
4617 // Check that the retry completed successfully
4618 assert!(
4619 *retry_completed.lock(),
4620 "Retry should have completed successfully"
4621 );
4622
4623 // Retry message should still exist but be marked as ui_only
4624 thread.read_with(cx, |thread, _| {
4625 let retry_msg = thread
4626 .message(retry_message_id)
4627 .expect("Retry message should still exist");
4628 assert!(retry_msg.ui_only, "Retry message should be ui_only");
4629 assert_eq!(
4630 retry_msg.role,
4631 Role::System,
4632 "Retry message should have System role"
4633 );
4634 });
4635 }
4636
4637 #[gpui::test]
4638 async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4639 let fs = init_test_settings(cx);
4640
4641 let project = create_test_project(&fs, cx, json!({})).await;
4642 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4643
4644 // Enable Burn Mode to allow retries
4645 thread.update(cx, |thread, _| {
4646 thread.set_completion_mode(CompletionMode::Burn);
4647 });
4648
4649 // Create a model that fails once then succeeds
4650 struct FailOnceModel {
4651 inner: Arc<FakeLanguageModel>,
4652 failed_once: Arc<Mutex<bool>>,
4653 }
4654
4655 impl LanguageModel for FailOnceModel {
4656 fn id(&self) -> LanguageModelId {
4657 self.inner.id()
4658 }
4659
4660 fn name(&self) -> LanguageModelName {
4661 self.inner.name()
4662 }
4663
4664 fn provider_id(&self) -> LanguageModelProviderId {
4665 self.inner.provider_id()
4666 }
4667
4668 fn provider_name(&self) -> LanguageModelProviderName {
4669 self.inner.provider_name()
4670 }
4671
4672 fn supports_tools(&self) -> bool {
4673 self.inner.supports_tools()
4674 }
4675
4676 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4677 self.inner.supports_tool_choice(choice)
4678 }
4679
4680 fn supports_images(&self) -> bool {
4681 self.inner.supports_images()
4682 }
4683
4684 fn telemetry_id(&self) -> String {
4685 self.inner.telemetry_id()
4686 }
4687
4688 fn max_token_count(&self) -> u64 {
4689 self.inner.max_token_count()
4690 }
4691
4692 fn count_tokens(
4693 &self,
4694 request: LanguageModelRequest,
4695 cx: &App,
4696 ) -> BoxFuture<'static, Result<u64>> {
4697 self.inner.count_tokens(request, cx)
4698 }
4699
4700 fn stream_completion(
4701 &self,
4702 request: LanguageModelRequest,
4703 cx: &AsyncApp,
4704 ) -> BoxFuture<
4705 'static,
4706 Result<
4707 BoxStream<
4708 'static,
4709 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4710 >,
4711 LanguageModelCompletionError,
4712 >,
4713 > {
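                // First call fails with ServerOverloaded; subsequent calls delegate to the inner fake model.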
4714 if !*self.failed_once.lock() {
4715 *self.failed_once.lock() = true;
4716 let provider = self.provider_name();
4717 // Return error on first attempt
4718 let stream = futures::stream::once(async move {
4719 Err(LanguageModelCompletionError::ServerOverloaded {
4720 provider,
4721 retry_after: None,
4722 })
4723 });
4724 async move { Ok(stream.boxed()) }.boxed()
4725 } else {
4726 // Succeed on retry
4727 self.inner.stream_completion(request, cx)
4728 }
4729 }
4730 }
4731
4732 let fail_once_model = Arc::new(FailOnceModel {
4733 inner: Arc::new(FakeLanguageModel::default()),
4734 failed_once: Arc::new(Mutex::new(false)),
4735 });
4736
4737 // Insert a user message
4738 thread.update(cx, |thread, cx| {
4739 thread.insert_user_message(
4740 "Test message",
4741 ContextLoadResult::default(),
4742 None,
4743 vec![],
4744 cx,
4745 );
4746 });
4747
4748 // Start completion with fail-once model
4749 thread.update(cx, |thread, cx| {
4750 thread.send_to_model(
4751 fail_once_model.clone(),
4752 CompletionIntent::UserPrompt,
4753 None,
4754 cx,
4755 );
4756 });
4757
4758 cx.run_until_parked();
4759
4760 // Verify retry state exists after first failure
4761 thread.read_with(cx, |thread, _| {
4762 assert!(
4763 thread.retry_state.is_some(),
4764 "Should have retry state after failure"
4765 );
4766 });
4767
4768 // Wait for retry delay
4769 cx.executor().advance_clock(BASE_RETRY_DELAY);
4770 cx.run_until_parked();
4771
4772         // The retry goes through FailOnceModel's success path, which delegates to the
4773         // inner FakeLanguageModel; drive that stream to completion manually.
4774 let inner_fake = fail_once_model.inner.clone();
4775
4776 // Wait a bit for the retry to start
4777 cx.run_until_parked();
4778
4779 // Check for pending completions and complete them
4780 if let Some(pending) = inner_fake.pending_completions().first() {
4781 inner_fake.send_completion_stream_text_chunk(pending, "Success!");
4782 inner_fake.end_completion_stream(pending);
4783 }
4784 cx.run_until_parked();
4785
4786 thread.read_with(cx, |thread, _| {
4787 assert!(
4788 thread.retry_state.is_none(),
4789 "Retry state should be cleared after successful completion"
4790 );
4791
4792 let has_assistant_message = thread
4793 .messages
4794 .iter()
4795 .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4796 assert!(
4797 has_assistant_message,
4798 "Should have an assistant message after successful retry"
4799 );
4800 });
4801 }
4802
4803 #[gpui::test]
4804 async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4805 let fs = init_test_settings(cx);
4806
4807 let project = create_test_project(&fs, cx, json!({})).await;
4808 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4809
4810 // Enable Burn Mode to allow retries
4811 thread.update(cx, |thread, _| {
4812 thread.set_completion_mode(CompletionMode::Burn);
4813 });
4814
4815 // Create a model that returns rate limit error with retry_after
4816 struct RateLimitModel {
4817 inner: Arc<FakeLanguageModel>,
4818 }
4819
4820 impl LanguageModel for RateLimitModel {
4821 fn id(&self) -> LanguageModelId {
4822 self.inner.id()
4823 }
4824
4825 fn name(&self) -> LanguageModelName {
4826 self.inner.name()
4827 }
4828
4829 fn provider_id(&self) -> LanguageModelProviderId {
4830 self.inner.provider_id()
4831 }
4832
4833 fn provider_name(&self) -> LanguageModelProviderName {
4834 self.inner.provider_name()
4835 }
4836
4837 fn supports_tools(&self) -> bool {
4838 self.inner.supports_tools()
4839 }
4840
4841 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4842 self.inner.supports_tool_choice(choice)
4843 }
4844
4845 fn supports_images(&self) -> bool {
4846 self.inner.supports_images()
4847 }
4848
4849 fn telemetry_id(&self) -> String {
4850 self.inner.telemetry_id()
4851 }
4852
4853 fn max_token_count(&self) -> u64 {
4854 self.inner.max_token_count()
4855 }
4856
4857 fn count_tokens(
4858 &self,
4859 request: LanguageModelRequest,
4860 cx: &App,
4861 ) -> BoxFuture<'static, Result<u64>> {
4862 self.inner.count_tokens(request, cx)
4863 }
4864
4865 fn stream_completion(
4866 &self,
4867 _request: LanguageModelRequest,
4868 _cx: &AsyncApp,
4869 ) -> BoxFuture<
4870 'static,
4871 Result<
4872 BoxStream<
4873 'static,
4874 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4875 >,
4876 LanguageModelCompletionError,
4877 >,
4878 > {
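            // Always fail with a rate-limit error carrying an explicit retry_after hint.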
4879 let provider = self.provider_name();
4880 async move {
4881 let stream = futures::stream::once(async move {
4882 Err(LanguageModelCompletionError::RateLimitExceeded {
4883 provider,
4884 retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4885 })
4886 });
4887 Ok(stream.boxed())
4888 }
4889 .boxed()
4890 }
4891
4892 fn as_fake(&self) -> &FakeLanguageModel {
4893 &self.inner
4894 }
4895 }
4896
4897 let model = Arc::new(RateLimitModel {
4898 inner: Arc::new(FakeLanguageModel::default()),
4899 });
4900
4901 // Insert a user message
4902 thread.update(cx, |thread, cx| {
4903 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4904 });
4905
4906 // Start completion
4907 thread.update(cx, |thread, cx| {
4908 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4909 });
4910
4911 cx.run_until_parked();
4912
4913 let retry_count = thread.update(cx, |thread, _| {
4914 thread
4915 .messages
4916 .iter()
4917 .filter(|m| {
4918 m.ui_only
4919 && m.segments.iter().any(|s| {
4920 if let MessageSegment::Text(text) = s {
4921 text.contains("rate limit exceeded")
4922 } else {
4923 false
4924 }
4925 })
4926 })
4927 .count()
4928 });
4929 assert_eq!(retry_count, 1, "Should have scheduled one retry");
4930
4931 thread.read_with(cx, |thread, _| {
4932 assert!(
4933 thread.retry_state.is_some(),
4934 "Rate limit errors should set retry_state"
4935 );
4936 if let Some(retry_state) = &thread.retry_state {
4937 assert_eq!(
4938 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4939 "Rate limit errors should use MAX_RETRY_ATTEMPTS"
4940 );
4941 }
4942 });
4943
4944 // Verify we have one retry message
4945 thread.read_with(cx, |thread, _| {
4946 let retry_messages = thread
4947 .messages
4948 .iter()
4949 .filter(|msg| {
4950 msg.ui_only
4951 && msg.segments.iter().any(|seg| {
4952 if let MessageSegment::Text(text) = seg {
4953 text.contains("rate limit exceeded")
4954 } else {
4955 false
4956 }
4957 })
4958 })
4959 .count();
4960 assert_eq!(
4961 retry_messages, 1,
4962 "Should have one rate limit retry message"
4963 );
4964 });
4965
4966         // Check that the retry message includes the attempt count
4967 thread.read_with(cx, |thread, _| {
4968 let retry_message = thread
4969 .messages
4970 .iter()
4971 .find(|msg| msg.role == Role::System && msg.ui_only)
4972 .expect("Should have a retry message");
4973
4974             // Because rate-limit retries go through retry_state, the message should include the attempt count
4975 if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
4976 assert!(
4977 text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)),
4978 "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
4979 );
4980 assert!(
4981 text.contains("Retrying"),
4982 "Rate limit retry message should contain retry text"
4983 );
4984 }
4985 });
4986 }
4987
4988 #[gpui::test]
4989 async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
4990 let fs = init_test_settings(cx);
4991
4992 let project = create_test_project(&fs, cx, json!({})).await;
4993 let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
4994
4995 // Insert a regular user message
4996 thread.update(cx, |thread, cx| {
4997 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4998 });
4999
5000 // Insert a UI-only message (like our retry notifications)
5001 thread.update(cx, |thread, cx| {
5002 let id = thread.next_message_id.post_inc();
5003 thread.messages.push(Message {
5004 id,
5005 role: Role::System,
5006 segments: vec![MessageSegment::Text(
5007 "This is a UI-only message that should not be sent to the model".to_string(),
5008 )],
5009 loaded_context: LoadedContext::default(),
5010 creases: Vec::new(),
5011 is_hidden: true,
5012 ui_only: true,
5013 });
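            // Notify observers that a message was added.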
5014 cx.emit(ThreadEvent::MessageAdded(id));
5015 });
5016
5017 // Insert another regular message
5018 thread.update(cx, |thread, cx| {
5019 thread.insert_user_message(
5020 "How are you?",
5021 ContextLoadResult::default(),
5022 None,
5023 vec![],
5024 cx,
5025 );
5026 });
5027
5028 // Generate the completion request
5029 let request = thread.update(cx, |thread, cx| {
5030 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5031 });
5032
5033 // Verify that the request only contains non-UI-only messages
5034 // Should have system prompt + 2 user messages, but not the UI-only message
5035 let user_messages: Vec<_> = request
5036 .messages
5037 .iter()
5038 .filter(|msg| msg.role == Role::User)
5039 .collect();
5040 assert_eq!(
5041 user_messages.len(),
5042 2,
5043 "Should have exactly 2 user messages"
5044 );
5045
5046 // Verify the UI-only content is not present anywhere in the request
5047 let request_text = request
5048 .messages
5049 .iter()
5050 .flat_map(|msg| &msg.content)
5051 .filter_map(|content| match content {
5052 MessageContent::Text(text) => Some(text.as_str()),
5053 _ => None,
5054 })
5055 .collect::<String>();
5056
5057 assert!(
5058 !request_text.contains("UI-only message"),
5059 "UI-only message content should not be in the request"
5060 );
5061
5062 // Verify the thread still has all 3 messages (including UI-only)
5063 thread.read_with(cx, |thread, _| {
5064 assert_eq!(
5065 thread.messages().count(),
5066 3,
5067 "Thread should have 3 messages"
5068 );
5069 assert_eq!(
5070 thread.messages().filter(|m| m.ui_only).count(),
5071 1,
5072 "Thread should have 1 UI-only message"
5073 );
5074 });
5075
5076 // Verify that UI-only messages are not serialized
5077 let serialized = thread
5078 .update(cx, |thread, cx| thread.serialize(cx))
5079 .await
5080 .unwrap();
5081 assert_eq!(
5082 serialized.messages.len(),
5083 2,
5084 "Serialized thread should only have 2 messages (no UI-only)"
5085 );
5086 }
5087
5088 #[gpui::test]
5089 async fn test_no_retry_without_burn_mode(cx: &mut TestAppContext) {
5090 let fs = init_test_settings(cx);
5091
5092 let project = create_test_project(&fs, cx, json!({})).await;
5093 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5094
5095 // Ensure we're in Normal mode (not Burn mode)
5096 thread.update(cx, |thread, _| {
5097 thread.set_completion_mode(CompletionMode::Normal);
5098 });
5099
5100 // Track error events
5101 let error_events = Arc::new(Mutex::new(Vec::new()));
5102 let error_events_clone = error_events.clone();
5103
5104 let _subscription = thread.update(cx, |_, cx| {
5105 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
5106 if let ThreadEvent::ShowError(error) = event {
5107 error_events_clone.lock().push(error.clone());
5108 }
5109 })
5110 });
5111
5112 // Create model that returns overloaded error
5113 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5114
5115 // Insert a user message
5116 thread.update(cx, |thread, cx| {
5117 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5118 });
5119
5120 // Start completion
5121 thread.update(cx, |thread, cx| {
5122 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5123 });
5124
5125 cx.run_until_parked();
5126
5127 // Verify no retry state was created
5128 thread.read_with(cx, |thread, _| {
5129 assert!(
5130 thread.retry_state.is_none(),
5131 "Should not have retry state in Normal mode"
5132 );
5133 });
5134
5135 // Check that a retryable error was reported
5136 let errors = error_events.lock();
5137 assert!(!errors.is_empty(), "Should have received an error event");
5138
5139 if let ThreadError::RetryableError {
5140 message: _,
5141 can_enable_burn_mode,
5142 } = &errors[0]
5143 {
5144 assert!(
5145 *can_enable_burn_mode,
5146 "Error should indicate burn mode can be enabled"
5147 );
5148 } else {
5149 panic!("Expected RetryableError, got {:?}", errors[0]);
5150 }
5151
5152 // Verify the thread is no longer generating
5153 thread.read_with(cx, |thread, _| {
5154 assert!(
5155 !thread.is_generating(),
5156 "Should not be generating after error without retry"
5157 );
5158 });
5159 }
5160
5161 #[gpui::test]
5162 async fn test_retry_canceled_on_stop(cx: &mut TestAppContext) {
5163 let fs = init_test_settings(cx);
5164
5165 let project = create_test_project(&fs, cx, json!({})).await;
5166 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5167
5168 // Enable Burn Mode to allow retries
5169 thread.update(cx, |thread, _| {
5170 thread.set_completion_mode(CompletionMode::Burn);
5171 });
5172
5173 // Create model that returns overloaded error
5174 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5175
5176 // Insert a user message
5177 thread.update(cx, |thread, cx| {
5178 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5179 });
5180
5181 // Start completion
5182 thread.update(cx, |thread, cx| {
5183 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5184 });
5185
5186 cx.run_until_parked();
5187
5188 // Verify retry was scheduled by checking for retry message
5189 let has_retry_message = thread.read_with(cx, |thread, _| {
5190 thread.messages.iter().any(|m| {
5191 m.ui_only
5192 && m.segments.iter().any(|s| {
5193 if let MessageSegment::Text(text) = s {
5194 text.contains("Retrying") && text.contains("seconds")
5195 } else {
5196 false
5197 }
5198 })
5199 })
5200 });
5201 assert!(has_retry_message, "Should have scheduled a retry");
5202
5203 // Cancel the completion before the retry happens
5204 thread.update(cx, |thread, cx| {
5205 thread.cancel_last_completion(None, cx);
5206 });
5207
5208 cx.run_until_parked();
5209
5210 // The retry should not have happened - no pending completions
5211 let fake_model = model.as_fake();
5212 assert_eq!(
5213 fake_model.pending_completions().len(),
5214 0,
5215 "Should have no pending completions after cancellation"
5216 );
5217
5218 // Verify the retry was canceled by checking retry state
5219 thread.read_with(cx, |thread, _| {
5220 if let Some(retry_state) = &thread.retry_state {
5221 panic!(
5222 "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5223 retry_state.attempt, retry_state.max_attempts, retry_state.intent
5224 );
5225 }
5226 });
5227 }
5228
5229 fn test_summarize_error(
5230 model: &Arc<dyn LanguageModel>,
5231 thread: &Entity<Thread>,
5232 cx: &mut TestAppContext,
5233 ) {
5234 thread.update(cx, |thread, cx| {
5235 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5236 thread.send_to_model(
5237 model.clone(),
5238 CompletionIntent::ThreadSummarization,
5239 None,
5240 cx,
5241 );
5242 });
5243
5244 let fake_model = model.as_fake();
5245 simulate_successful_response(fake_model, cx);
5246
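        // Completing the assistant response should kick off automatic summarization.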
5247 thread.read_with(cx, |thread, _| {
5248 assert!(matches!(thread.summary(), ThreadSummary::Generating));
5249 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5250 });
5251
5252         // End the summary stream without emitting any text; the summary should transition to Error
5253 cx.run_until_parked();
5254 fake_model.end_last_completion_stream();
5255 cx.run_until_parked();
5256
5257 // State is set to Error and default message
5258 thread.read_with(cx, |thread, _| {
5259 assert!(matches!(thread.summary(), ThreadSummary::Error));
5260 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5261 });
5262 }
5263
5264 fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
5265 cx.run_until_parked();
5266 fake_model.send_last_completion_stream_text_chunk("Assistant response");
5267 fake_model.end_last_completion_stream();
5268 cx.run_until_parked();
5269 }
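
    // Several tests above repeat a closure that counts "Retrying ... seconds"
    // notifications. A shared helper along these lines (a sketch, not wired into the
    // tests as written) would collapse that duplication; it assumes only
    // `Thread::messages` and `MessageSegment` as already used throughout this module.
    #[allow(dead_code)]
    fn count_retry_notifications(thread: &Entity<Thread>, cx: &mut TestAppContext) -> usize {
        thread.read_with(cx, |thread, _| {
            thread
                .messages()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            matches!(s, MessageSegment::Text(text)
                                if text.contains("Retrying") && text.contains("seconds"))
                        })
                })
                .count()
        })
    }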
5270
5271 fn init_test_settings(cx: &mut TestAppContext) -> Arc<dyn Fs> {
5272 let fs = FakeFs::new(cx.executor());
5273 cx.update(|cx| {
5274 let settings_store = SettingsStore::test(cx);
5275 cx.set_global(settings_store);
5276 language::init(cx);
5277 Project::init_settings(cx);
5278 AgentSettings::register(cx);
5279 prompt_store::init(cx);
5280 thread_store::init(fs.clone(), cx);
5281 workspace::init_settings(cx);
5282 language_model::init_settings(cx);
5283 theme::init(theme::LoadThemes::JustBase, cx);
5284 ToolRegistry::default_global(cx);
5285 assistant_tool::init(cx);
5286
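            // assistant_tools needs an HTTP client; a fake client that always returns
            // 200 responses is sufficient here.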
5287 let http_client = Arc::new(http_client::HttpClientWithUrl::new(
5288 http_client::FakeHttpClient::with_200_response(),
5289 "http://localhost".to_string(),
5290 None,
5291 ));
5292 assistant_tools::init(http_client, cx);
5293 });
5294 fs
5295 }
5296
5297 // Helper to create a test project with test files
5298 async fn create_test_project(
5299 fs: &Arc<dyn Fs>,
5300 cx: &mut TestAppContext,
5301 files: serde_json::Value,
5302 ) -> Entity<Project> {
5303 fs.as_fake().insert_tree(path!("/test"), files).await;
5304 Project::test(fs.clone(), [path!("/test").as_ref()], cx).await
5305 }
5306
5307 async fn setup_test_environment(
5308 cx: &mut TestAppContext,
5309 project: Entity<Project>,
5310 ) -> (
5311 Entity<Workspace>,
5312 Entity<ThreadStore>,
5313 Entity<Thread>,
5314 Entity<ContextStore>,
5315 Arc<dyn LanguageModel>,
5316 ) {
5317 let (workspace, cx) =
5318 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5319
5320 let thread_store = cx
5321 .update(|_, cx| {
5322 ThreadStore::load(
5323 project.clone(),
5324 cx.new(|_| ToolWorkingSet::default()),
5325 None,
5326 Arc::new(PromptBuilder::new(None).unwrap()),
5327 cx,
5328 )
5329 })
5330 .await
5331 .unwrap();
5332
5333 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5334 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5335
5336 let provider = Arc::new(FakeLanguageModelProvider::default());
5337 let model = provider.test_model();
5338 let model: Arc<dyn LanguageModel> = Arc::new(model);
5339
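        // Register the fake model as both the default and the thread-summary model so
        // every completion path in these tests resolves to it.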
5340 cx.update(|_, cx| {
5341 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5342 registry.set_default_model(
5343 Some(ConfiguredModel {
5344 provider: provider.clone(),
5345 model: model.clone(),
5346 }),
5347 cx,
5348 );
5349 registry.set_thread_summary_model(
5350 Some(ConfiguredModel {
5351 provider,
5352 model: model.clone(),
5353 }),
5354 cx,
5355 );
5356 })
5357 });
5358
5359 (workspace, thread_store, thread, context_store, model)
5360 }
5361
5362 async fn add_file_to_context(
5363 project: &Entity<Project>,
5364 context_store: &Entity<ContextStore>,
5365 path: &str,
5366 cx: &mut TestAppContext,
5367 ) -> Result<Entity<language::Buffer>> {
5368 let buffer_path = project
5369 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5370 .unwrap();
5371
5372 let buffer = project
5373 .update(cx, |project, cx| {
5374 project.open_buffer(buffer_path.clone(), cx)
5375 })
5376 .await
5377 .unwrap();
5378
5379 context_store.update(cx, |context_store, cx| {
5380 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5381 });
5382
5383 Ok(buffer)
5384 }
5385}