1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9 tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
10};
11use action_log::ActionLog;
12use agent_settings::{
13 AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
14 SUMMARIZE_THREAD_PROMPT,
15};
16use anyhow::{Result, anyhow};
17use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet};
18use chrono::{DateTime, Utc};
19use client::{ModelRequestUsage, RequestUsage};
20use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
21use collections::HashMap;
22use futures::{FutureExt, StreamExt as _, future::Shared};
23use git::repository::DiffType;
24use gpui::{
25 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
26 WeakEntity, Window,
27};
28use http_client::StatusCode;
29use language_model::{
30 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
31 LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
32 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
33 LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
34 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
35 TokenUsage,
36};
37use postage::stream::Stream as _;
38use project::{
39 Project,
40 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
41};
42use prompt_store::{ModelContext, PromptBuilder};
43use schemars::JsonSchema;
44use serde::{Deserialize, Serialize};
45use settings::Settings;
46use std::{
47 io::Write,
48 ops::Range,
49 sync::Arc,
50 time::{Duration, Instant},
51};
52use thiserror::Error;
53use util::{ResultExt as _, post_inc};
54use uuid::Uuid;
55
/// Maximum number of times a failed completion request is retried before giving up.
const MAX_RETRY_ATTEMPTS: u8 = 4;
/// Base delay used when scheduling retries (see [`RetryStrategy`]).
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
58
/// How a failed completion request should be retried.
#[derive(Debug, Clone)]
enum RetryStrategy {
    /// Retry with a growing delay, starting at `initial_delay`.
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    /// Retry with the same `delay` between every attempt.
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}
70
71#[derive(
72 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
73)]
74pub struct ThreadId(Arc<str>);
75
76impl ThreadId {
77 pub fn new() -> Self {
78 Self(Uuid::new_v4().to_string().into())
79 }
80}
81
82impl std::fmt::Display for ThreadId {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 write!(f, "{}", self.0)
85 }
86}
87
88impl From<&str> for ThreadId {
89 fn from(value: &str) -> Self {
90 Self(value.into())
91 }
92}
93
94/// The ID of the user prompt that initiated a request.
95///
96/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
97#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
98pub struct PromptId(Arc<str>);
99
100impl PromptId {
101 pub fn new() -> Self {
102 Self(Uuid::new_v4().to_string().into())
103 }
104}
105
106impl std::fmt::Display for PromptId {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 write!(f, "{}", self.0)
109 }
110}
111
112#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
113pub struct MessageId(pub usize);
114
115impl MessageId {
116 fn post_inc(&mut self) -> Self {
117 Self(post_inc(&mut self.0))
118 }
119
120 pub fn as_usize(&self) -> usize {
121 self.0
122 }
123}
124
125/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
134
135/// A message in a [`Thread`].
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to this message, already loaded.
    pub loaded_context: LoadedContext,
    /// Creases for re-creating context pills when editing this message.
    pub creases: Vec<MessageCrease>,
    /// Whether the message is hidden (e.g. the auto-inserted "continue"
    /// message — see `insert_invisible_continue_message`).
    pub is_hidden: bool,
    /// UI-only messages are never persisted (filtered out in `serialize`).
    pub ui_only: bool,
}
146
147impl Message {
148 /// Returns whether the message contains any meaningful text that should be displayed
149 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
150 pub fn should_display_content(&self) -> bool {
151 self.segments.iter().all(|segment| segment.should_display())
152 }
153
154 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
155 if let Some(MessageSegment::Thinking {
156 text: segment,
157 signature: current_signature,
158 }) = self.segments.last_mut()
159 {
160 if let Some(signature) = signature {
161 *current_signature = Some(signature);
162 }
163 segment.push_str(text);
164 } else {
165 self.segments.push(MessageSegment::Thinking {
166 text: text.to_string(),
167 signature,
168 });
169 }
170 }
171
172 pub fn push_redacted_thinking(&mut self, data: String) {
173 self.segments.push(MessageSegment::RedactedThinking(data));
174 }
175
176 pub fn push_text(&mut self, text: &str) {
177 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
178 segment.push_str(text);
179 } else {
180 self.segments.push(MessageSegment::Text(text.to_string()));
181 }
182 }
183
184 pub fn to_message_content(&self) -> String {
185 let mut result = String::new();
186
187 if !self.loaded_context.text.is_empty() {
188 result.push_str(&self.loaded_context.text);
189 }
190
191 for segment in &self.segments {
192 match segment {
193 MessageSegment::Text(text) => result.push_str(text),
194 MessageSegment::Thinking { text, .. } => {
195 result.push_str("<think>\n");
196 result.push_str(text);
197 result.push_str("\n</think>");
198 }
199 MessageSegment::RedactedThinking(_) => {}
200 }
201 }
202
203 result
204 }
205}
206
/// One contiguous piece of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain text content.
    Text(String),
    /// Extended "thinking" output, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Thinking content the provider returned in redacted (opaque) form.
    RedactedThinking(String),
}

impl MessageSegment {
    /// Whether this segment carries content worth rendering.
    ///
    /// Fix: this previously returned `text.is_empty()`, inverting the intent —
    /// empty segments were reported as displayable and non-empty ones as not.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking is opaque data and is never rendered.
            Self::RedactedThinking(_) => false,
        }
    }

    /// The segment's text when it is a plain [`MessageSegment::Text`].
    pub fn text(&self) -> Option<&str> {
        match self {
            MessageSegment::Text(text) => Some(text),
            _ => None,
        }
    }
}
233
/// Snapshot of the project's worktrees at a point in time.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub timestamp: DateTime<Utc>,
}
239
/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}
245
/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text, if one was captured.
    pub diff: Option<String>,
}
253
/// Pairs a message with a git checkpoint of the project's state.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
259
/// User feedback recorded for a message (thumbs up / thumbs down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
265
/// State of the most recent checkpoint-restore operation.
pub enum LastRestoreCheckpoint {
    /// Restore is still running for the given message.
    Pending {
        message_id: MessageId,
    },
    /// Restore failed; `error` holds the message to display.
    Error {
        message_id: MessageId,
        error: String,
    },
}
275
276impl LastRestoreCheckpoint {
277 pub fn message_id(&self) -> MessageId {
278 match self {
279 LastRestoreCheckpoint::Pending { message_id } => *message_id,
280 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
281 }
282 }
283}
284
/// Lifecycle of the thread's detailed (long-form) summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    /// No detailed summary has been requested yet.
    #[default]
    NotGenerated,
    /// Generation is in flight, anchored at `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// A summary was generated up to and including `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
297
298impl DetailedSummaryState {
299 fn text(&self) -> Option<SharedString> {
300 if let Self::Generated { text, .. } = self {
301 Some(text.clone())
302 } else {
303 None
304 }
305 }
306}
307
/// Running token totals for a thread, relative to the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Context-window size; 0 when no model is selected.
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies the current usage against the warning/exceeded thresholds.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be tuned via an env var.
        // Fix: a malformed value now falls back to the default instead of
        // panicking via `.unwrap()` on the parse.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Coarse classification of how close a thread is to its context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
352
/// Progress of a pending completion request through the request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request is being sent.
    Sending,
    /// Request is waiting in a queue at the given position.
    Queued { position: usize },
    /// Request has started streaming.
    Started,
}
359
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Kept sorted by id; `message()` relies on this for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken at the start of the current turn; finalized or
    // discarded once generation and tool use complete.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    // Timestamp of the most recent streamed chunk; used for staleness checks.
    last_received_chunk_at: Option<Instant>,
    // Optional observer of outgoing requests and their raw completion events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
    last_error_context: Option<(Arc<dyn LanguageModel>, CompletionIntent)>,
}
400
/// Bookkeeping for an in-flight retry sequence.
#[derive(Clone, Debug)]
struct RetryState {
    /// Current attempt number.
    attempt: u8,
    /// Attempts allowed before giving up.
    max_attempts: u8,
    /// Intent of the request being retried.
    intent: CompletionIntent,
}
407
/// Lifecycle of the thread's short auto-generated summary.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// No summary generated yet.
    Pending,
    /// Summary generation is in flight.
    Generating,
    /// Summary is available.
    Ready(SharedString),
    /// Summary generation failed.
    Error,
}
415
416impl ThreadSummary {
417 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
418
419 pub fn or_default(&self) -> SharedString {
420 self.unwrap_or(Self::DEFAULT)
421 }
422
423 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
424 self.ready().unwrap_or_else(|| message.into())
425 }
426
427 pub fn ready(&self) -> Option<SharedString> {
428 match self {
429 ThreadSummary::Ready(summary) => Some(summary.clone()),
430 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
431 }
432 }
433}
434
/// Error state recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
442
443impl Thread {
    /// Creates an empty thread, taking its defaults from global settings and
    /// the global language-model registry.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project's starting state in the background once.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_error_context: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
498
    /// Reconstructs a thread from its serialized form, re-resolving the model
    /// and profile against current settings when the stored ones are absent.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the last persisted message.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the stored model selection; fall back to the registry default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; the
                    // structured contexts and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            message_feedback: HashMap::default(),
            last_error_context: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
621
    /// Installs a callback that observes each completion request and the raw
    /// events received for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
629
    /// Stable identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
633
    /// The agent profile currently active for this thread.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }
637
638 pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
639 if &id != self.profile.id() {
640 self.profile = AgentProfile::new(id, self.tools.clone());
641 cx.emit(ThreadEvent::ProfileChanged);
642 }
643 }
644
    /// True when the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
648
    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
652
    /// Stamps the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
656
    /// Replaces `last_prompt_id` with a fresh id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
660
    /// Handle to the shared project context (supplied as `system_prompt` in `new`).
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
664
665 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
666 if self.configured_model.is_none() {
667 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
668 }
669 self.configured_model.clone()
670 }
671
    /// The currently selected model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }
675
    /// Overrides the selected model (or clears it) and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
680
    /// Current summary state of the thread.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
684
685 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
686 let current_summary = match &self.summary {
687 ThreadSummary::Pending | ThreadSummary::Generating => return,
688 ThreadSummary::Ready(summary) => summary,
689 ThreadSummary::Error => &ThreadSummary::DEFAULT,
690 };
691
692 let mut new_summary = new_summary.into();
693
694 if new_summary.is_empty() {
695 new_summary = ThreadSummary::DEFAULT;
696 }
697
698 if current_summary != &new_summary {
699 self.summary = ThreadSummary::Ready(new_summary);
700 cx.emit(ThreadEvent::SummaryChanged);
701 }
702 }
703
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
707
    /// Sets the completion mode used for requests from this thread.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
711
712 pub fn message(&self, id: MessageId) -> Option<&Message> {
713 let index = self
714 .messages
715 .binary_search_by(|message| message.id.cmp(&id))
716 .ok()?;
717
718 self.messages.get(index)
719 }
720
    /// Iterator over the thread's messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
724
725 pub fn is_generating(&self) -> bool {
726 !self.pending_completions.is_empty() || !self.all_tools_finished()
727 }
728
    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        // NOTE(review): this only inspects `last_received_chunk_at`; the doc
        // above assumes that field is cleared when generation stops — confirm
        // where it is reset elsewhere in this file.
        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }
737
    /// Records that a streaming chunk just arrived (for staleness tracking).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
741
742 pub fn queue_state(&self) -> Option<QueueState> {
743 self.pending_completions
744 .first()
745 .map(|pending_completion| pending_completion.queue_state)
746 }
747
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
751
752 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
753 self.tool_use
754 .pending_tool_uses()
755 .into_iter()
756 .find(|tool_use| &tool_use.id == id)
757 }
758
    /// Pending tool uses whose status requires user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
765
    /// True if any tool use is still pending (including errored entries —
    /// compare with `all_tools_finished`).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
769
    /// The checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
773
    /// Restores the project to `checkpoint`'s git state, truncating the thread
    /// back to the checkpoint's message on success.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Surface the in-progress restore immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed state around so the UI can show the error.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
809
810 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
811 let pending_checkpoint = if self.is_generating() {
812 return;
813 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
814 checkpoint
815 } else {
816 return;
817 };
818
819 self.finalize_checkpoint(pending_checkpoint, cx);
820 }
821
    /// Compares `pending_checkpoint` against a fresh checkpoint of the current
    /// git state: if the project changed during the turn, the pending
    /// checkpoint is committed and a new one started; otherwise it is kept
    /// pending for reuse.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison is treated as "changed" so the
                // checkpoint still gets recorded.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                this.update(cx, |this, cx| {
                    this.pending_checkpoint = if equal {
                        Some(pending_checkpoint)
                    } else {
                        this.insert_checkpoint(pending_checkpoint, cx);
                        Some(ThreadCheckpoint {
                            message_id: this.next_message_id,
                            git_checkpoint: final_checkpoint,
                        })
                    }
                })?;

                Ok(())
            }
            // If taking a new checkpoint failed, still record the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
862
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
869
    /// State of the most recent checkpoint restore, if one occurred.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
873
874 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
875 let Some(message_ix) = self
876 .messages
877 .iter()
878 .rposition(|message| message.id == message_id)
879 else {
880 return;
881 };
882 for deleted_message in self.messages.drain(message_ix..) {
883 self.checkpoints_by_message.remove(&deleted_message.id);
884 }
885 cx.notify();
886 }
887
    /// Iterates the contexts attached to the message with the given id.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
895
896 pub fn is_turn_end(&self, ix: usize) -> bool {
897 if self.messages.is_empty() {
898 return false;
899 }
900
901 if !self.is_generating() && ix == self.messages.len() - 1 {
902 return true;
903 }
904
905 let Some(message) = self.messages.get(ix) else {
906 return false;
907 };
908
909 if message.role != Role::Assistant {
910 return false;
911 }
912
913 self.messages
914 .get(ix + 1)
915 .and_then(|message| {
916 self.message(message.id)
917 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
918 })
919 .unwrap_or(false)
920 }
921
    /// Whether the provider reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
925
926 /// Returns whether all of the tool uses have finished running.
927 pub fn all_tools_finished(&self) -> bool {
928 // If the only pending tool uses left are the ones with errors, then
929 // that means that we've finished running all of the pending tools.
930 self.tool_use
931 .pending_tool_uses()
932 .iter()
933 .all(|pending_tool_use| pending_tool_use.status.is_error())
934 }
935
936 /// Returns whether any pending tool uses may perform edits
937 pub fn has_pending_edit_tool_uses(&self) -> bool {
938 self.tool_use
939 .pending_tool_uses()
940 .iter()
941 .filter(|pending_tool_use| !pending_tool_use.status.is_error())
942 .any(|pending_tool_use| pending_tool_use.may_perform_edits)
943 }
944
    /// Tool uses issued by the given (assistant) message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, &self.project, cx)
    }
948
    /// All tool results produced for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
955
    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
959
960 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
961 match &self.tool_use.tool_result(id)?.content {
962 LanguageModelToolResultContent::Text(text) => Some(text),
963 LanguageModelToolResultContent::Image(_) => {
964 // TODO: We should display image
965 None
966 }
967 }
968 }
969
    /// The UI card associated with a tool's result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
973
974 /// Return tools that are both enabled and supported by the model
975 pub fn available_tools(
976 &self,
977 cx: &App,
978 model: Arc<dyn LanguageModel>,
979 ) -> Vec<LanguageModelRequestTool> {
980 if model.supports_tools() {
981 self.profile
982 .enabled_tools(cx)
983 .into_iter()
984 .filter_map(|(name, tool)| {
985 // Skip tools that cannot be supported
986 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
987 Some(LanguageModelRequestTool {
988 name: name.into(),
989 description: tool.description(),
990 input_schema,
991 })
992 })
993 .collect()
994 } else {
995 Vec::default()
996 }
997 }
998
    /// Inserts a user message, records context buffers as read in the action
    /// log, and arms a pending checkpoint when a git checkpoint is provided.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                // Mark every buffer referenced by the attached context as read.
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        // The checkpoint is finalized or discarded once this turn completes.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        message_id
    }
1033
1034 pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
1035 let id = self.insert_message(
1036 Role::User,
1037 vec![MessageSegment::Text("Continue where you left off".into())],
1038 LoadedContext::default(),
1039 vec![],
1040 true,
1041 cx,
1042 );
1043 self.pending_checkpoint = None;
1044
1045 id
1046 }
1047
1048 pub fn insert_assistant_message(
1049 &mut self,
1050 segments: Vec<MessageSegment>,
1051 cx: &mut Context<Self>,
1052 ) -> MessageId {
1053 self.insert_message(
1054 Role::Assistant,
1055 segments,
1056 LoadedContext::default(),
1057 Vec::new(),
1058 false,
1059 cx,
1060 )
1061 }
1062
1063 pub fn insert_message(
1064 &mut self,
1065 role: Role,
1066 segments: Vec<MessageSegment>,
1067 loaded_context: LoadedContext,
1068 creases: Vec<MessageCrease>,
1069 is_hidden: bool,
1070 cx: &mut Context<Self>,
1071 ) -> MessageId {
1072 let id = self.next_message_id.post_inc();
1073 self.messages.push(Message {
1074 id,
1075 role,
1076 segments,
1077 loaded_context,
1078 creases,
1079 is_hidden,
1080 ui_only: false,
1081 });
1082 self.touch_updated_at();
1083 cx.emit(ThreadEvent::MessageAdded(id));
1084 id
1085 }
1086
1087 pub fn edit_message(
1088 &mut self,
1089 id: MessageId,
1090 new_role: Role,
1091 new_segments: Vec<MessageSegment>,
1092 creases: Vec<MessageCrease>,
1093 loaded_context: Option<LoadedContext>,
1094 checkpoint: Option<GitStoreCheckpoint>,
1095 cx: &mut Context<Self>,
1096 ) -> bool {
1097 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1098 return false;
1099 };
1100 message.role = new_role;
1101 message.segments = new_segments;
1102 message.creases = creases;
1103 if let Some(context) = loaded_context {
1104 message.loaded_context = context;
1105 }
1106 if let Some(git_checkpoint) = checkpoint {
1107 self.checkpoints_by_message.insert(
1108 id,
1109 ThreadCheckpoint {
1110 message_id: id,
1111 git_checkpoint,
1112 },
1113 );
1114 }
1115 self.touch_updated_at();
1116 cx.emit(ThreadEvent::MessageEdited(id));
1117 true
1118 }
1119
1120 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1121 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1122 return false;
1123 };
1124 self.messages.remove(index);
1125 self.touch_updated_at();
1126 cx.emit(ThreadEvent::MessageDeleted(id));
1127 true
1128 }
1129
1130 /// Returns the representation of this [`Thread`] in a textual form.
1131 ///
1132 /// This is the representation we use when attaching a thread as context to another thread.
1133 pub fn text(&self) -> String {
1134 let mut text = String::new();
1135
1136 for message in &self.messages {
1137 text.push_str(match message.role {
1138 language_model::Role::User => "User:",
1139 language_model::Role::Assistant => "Agent:",
1140 language_model::Role::System => "System:",
1141 });
1142 text.push('\n');
1143
1144 for segment in &message.segments {
1145 match segment {
1146 MessageSegment::Text(content) => text.push_str(content),
1147 MessageSegment::Thinking { text: content, .. } => {
1148 text.push_str(&format!("<think>{}</think>", content))
1149 }
1150 MessageSegment::RedactedThinking(_) => {}
1151 }
1152 }
1153 text.push('\n');
1154 }
1155
1156 text
1157 }
1158
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot, then captures the thread state:
    /// all non-UI-only messages (with their segments, tool uses, tool
    /// results, context text, and creases), cumulative/request token usage,
    /// summary and exceeded-window state, and the configured model/profile.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    // UI-only messages are presentation artifacts; never persist them.
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1245
    /// Returns how many more turns this thread is still allowed to send to the model.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1249
    /// Sets the number of turns this thread may still send to the model
    /// (consumed by [`Thread::send_to_model`]).
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1253
    /// Sends the current thread state to `model` as a completion request.
    ///
    /// No-op when the remaining-turn budget is exhausted; otherwise consumes
    /// one turn, flushes any pending notification tool calls, finalizes the
    /// pending git checkpoint, and starts streaming the completion.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        // Respect the turn budget configured via `set_remaining_turns`.
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        // Notifications must be attached before the request is built so the
        // model sees them in this turn.
        self.flush_notifications(model.clone(), intent, cx);

        let _checkpoint = self.finalize_pending_checkpoint(cx);
        self.stream_completion(
            self.to_completion_request(model.clone(), intent, cx),
            model,
            intent,
            window,
            cx,
        );
    }
1278
    /// Builds the [`LanguageModelRequest`] representing this thread.
    ///
    /// Generates the system prompt from the shared project context (emitting
    /// a [`ThreadEvent::ShowError`] when the context is missing or prompt
    /// generation fails), converts every non-UI message into request
    /// messages — pairing each completed tool use with a follow-up `User`
    /// message carrying its tool results — marks the newest message with no
    /// pending tool uses as the prompt-cache anchor, and attaches the
    /// available tools and completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
            thinking_allowed: true,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt is derived from the project context; failures are
        // surfaced to the user but do not abort building the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the newest request message eligible for prompt caching
        // (i.e. no tool uses still pending in or before it).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results must be delivered in a separate User message that
            // immediately follows the message containing the tool uses.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            // Models without burn-mode support always run in normal mode.
            Some(CompletionMode::Normal.into())
        };

        request
    }
1443
1444 fn to_summarize_request(
1445 &self,
1446 model: &Arc<dyn LanguageModel>,
1447 intent: CompletionIntent,
1448 added_user_message: String,
1449 cx: &App,
1450 ) -> LanguageModelRequest {
1451 let mut request = LanguageModelRequest {
1452 thread_id: None,
1453 prompt_id: None,
1454 intent: Some(intent),
1455 mode: None,
1456 messages: vec![],
1457 tools: Vec::new(),
1458 tool_choice: None,
1459 stop: Vec::new(),
1460 temperature: AgentSettings::temperature_for_model(model, cx),
1461 thinking_allowed: false,
1462 };
1463
1464 for message in &self.messages {
1465 let mut request_message = LanguageModelRequestMessage {
1466 role: message.role,
1467 content: Vec::new(),
1468 cache: false,
1469 };
1470
1471 for segment in &message.segments {
1472 match segment {
1473 MessageSegment::Text(text) => request_message
1474 .content
1475 .push(MessageContent::Text(text.clone())),
1476 MessageSegment::Thinking { .. } => {}
1477 MessageSegment::RedactedThinking(_) => {}
1478 }
1479 }
1480
1481 if request_message.content.is_empty() {
1482 continue;
1483 }
1484
1485 request.messages.push(request_message);
1486 }
1487
1488 request.messages.push(LanguageModelRequestMessage {
1489 role: Role::User,
1490 content: vec![MessageContent::Text(added_user_message)],
1491 cache: false,
1492 });
1493
1494 request
1495 }
1496
    /// Insert auto-generated notifications (if any) to the thread
    ///
    /// Only runs for user-driven completions (`UserPrompt`/`ToolResults`);
    /// internal intents such as summarization and inline assists never
    /// inject notifications.
    fn flush_notifications(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) {
        match intent {
            CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
                if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
                    // Mark the simulated tool call as finished so the UI
                    // renders it like any other completed tool use.
                    cx.emit(ThreadEvent::ToolFinished {
                        tool_use_id: pending_tool_use.id.clone(),
                        pending_tool_use: Some(pending_tool_use),
                    });
                }
            }
            CompletionIntent::ThreadSummarization
            | CompletionIntent::ThreadContextSummarization
            | CompletionIntent::CreateFile
            | CompletionIntent::EditFile
            | CompletionIntent::InlineAssist
            | CompletionIntent::TerminalInlineAssist
            | CompletionIntent::GenerateGitCommitMessage => {}
        };
    }
1522
    /// Surfaces unnotified user edits to the model as a simulated
    /// `project_notifications` tool call attached to the latest Assistant
    /// message.
    ///
    /// Returns `None` when the tool is unavailable or disabled for the
    /// current profile, when there are no pending user-edit notifications,
    /// or when the thread has no Assistant message to attach to; otherwise
    /// returns the pending tool use produced by recording the tool output.
    fn attach_tracked_files_state(
        &mut self,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Option<PendingToolUse> {
        // Represent notification as a simulated `project_notifications` tool call
        let tool_name = Arc::from("project_notifications");
        let tool = self.tools.read(cx).tool(&tool_name, cx)?;

        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return None;
        }

        // Nothing to report: no user edits since the last notification.
        if self
            .action_log
            .update(cx, |log, cx| log.unnotified_user_edits(cx).is_none())
        {
            return None;
        }

        let input = serde_json::json!({});
        let request = Arc::new(LanguageModelRequest::default()); // unused
        let window = None;
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model.clone(),
            window,
            cx,
        );

        // Synthesize a unique-enough id; real tool-use ids come from the model.
        let tool_use_id =
            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));

        let tool_use = LanguageModelToolUse {
            id: tool_use_id.clone(),
            name: tool_name.clone(),
            raw_input: "{}".to_string(),
            input: serde_json::json!({}),
            is_input_complete: true,
        };

        // NOTE(review): blocks until the tool's output future resolves —
        // presumably this tool completes quickly; confirm if it ever does I/O.
        let tool_output = cx.background_executor().block(tool_result.output);

        // Attach a project_notification tool call to the latest existing
        // Assistant message. We cannot create a new Assistant message
        // because thinking models require a `thinking` block that we
        // cannot mock. We cannot send a notification as a normal
        // (non-tool-use) User message because this distracts Agent
        // too much.
        let tool_message_id = self
            .messages
            .iter()
            .enumerate()
            .rfind(|(_, message)| message.role == Role::Assistant)
            .map(|(_, message)| message.id)?;

        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: self.last_prompt_id.clone(),
        };

        self.tool_use
            .request_tool_use(tool_message_id, tool_use, tool_use_metadata, cx);

        self.tool_use.insert_tool_output(
            tool_use_id,
            tool_name,
            tool_output,
            self.configured_model.as_ref(),
            self.completion_mode,
        )
    }
1599
    /// Streams a completion for `request` from `model`, applying events to
    /// the thread as they arrive.
    ///
    /// Spawns a task that consumes the model's event stream (text, thinking,
    /// tool uses, usage/status updates), then on completion dispatches on the
    /// stop reason (running pending tools, handling refusals) or on the
    /// error (payment/limit errors, context-window overflow, retryable
    /// failures). The task is tracked in `self.pending_completions`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record request/response pairs when a callback is registered
        // (used by tests/telemetry hooks).
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let completion_mode = request
            .mode
            .unwrap_or(cloud_llm_client::CompletionMode::Normal);

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Id of the Assistant message this response is streaming into,
                // once one exists.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        match event? {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // only add the delta since the last update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking { data } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses always attach to an Assistant
                                // message; create one if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some(tool_use.input.clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            } => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued { position } => {
                                            completion.queue_state =
                                                QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code,
                                            message,
                                            request_id: _,
                                            retry_after,
                                        } => {
                                            return Err(
                                                LanguageModelCompletionError::from_cloud_failure(
                                                    model.upstream_provider_name(),
                                                    code,
                                                    message,
                                                    retry_after.map(Duration::from_secs_f64),
                                                ),
                                            );
                                        }
                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
                                            thread.update_model_request_usage(
                                                amount as u32,
                                                limit,
                                                cx,
                                            );
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        Ok(())
                    })??;

                    // Let other tasks run between events to keep the UI responsive.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;
            let mut retry_scheduled = false;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => {
                            match stop_reason {
                                StopReason::ToolUse => {
                                    let tool_uses =
                                        thread.use_pending_tools(window, model.clone(), cx);
                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                                }
                                StopReason::EndTurn | StopReason::MaxTokens => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });
                                }
                                StopReason::Refusal => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });

                                    // Remove the turn that was refused.
                                    //
                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                    {
                                        let mut messages_to_remove = Vec::new();

                                        // Walk backwards, collecting messages until
                                        // the start of the refused user turn.
                                        for (ix, message) in
                                            thread.messages.iter().enumerate().rev()
                                        {
                                            messages_to_remove.push(message.id);

                                            if message.role == Role::User {
                                                if ix == 0 {
                                                    break;
                                                }

                                                if let Some(prev_message) =
                                                    thread.messages.get(ix - 1)
                                                    && prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }

                                        for message_id in messages_to_remove {
                                            thread.delete_message(message_id, cx);
                                        }
                                    }

                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                        header: "Language model refusal".into(),
                                        message:
                                            "Model refused to generate content for safety reasons."
                                                .into(),
                                    }));
                                }
                            }

                            // We successfully completed, so cancel any remaining retries.
                            thread.retry_state = None;
                        }
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(completion_error) =
                                error.downcast_ref::<LanguageModelCompletionError>()
                            {
                                match &completion_error {
                                    LanguageModelCompletionError::PromptTooLarge {
                                        tokens, ..
                                    } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread
                                                .total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(
                                                    model
                                                        .max_token_count_for_mode(completion_mode)
                                                        .saturating_add(1),
                                                )
                                        });
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: tokens,
                                        });
                                        cx.notify();
                                    }
                                    _ => {
                                        if let Some(retry_strategy) =
                                            Thread::get_retry_strategy(completion_error)
                                        {
                                            log::info!(
                                                "Retrying with {:?} for language model completion error {:?}",
                                                retry_strategy,
                                                completion_error
                                            );

                                            retry_scheduled = thread
                                                .handle_retryable_error_with_delay(
                                                    completion_error,
                                                    Some(retry_strategy),
                                                    model.clone(),
                                                    intent,
                                                    window,
                                                    cx,
                                                );
                                        }
                                    }
                                }
                            }

                            if !retry_scheduled {
                                thread.cancel_last_completion(window, cx);
                            }
                        }
                    }

                    if !retry_scheduled {
                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
                    }

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
2041
2042 pub fn summarize(&mut self, cx: &mut Context<Self>) {
2043 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2044 println!("No thread summary model");
2045 return;
2046 };
2047
2048 if !model.provider.is_authenticated(cx) {
2049 return;
2050 }
2051
2052 let request = self.to_summarize_request(
2053 &model.model,
2054 CompletionIntent::ThreadSummarization,
2055 SUMMARIZE_THREAD_PROMPT.into(),
2056 cx,
2057 );
2058
2059 self.summary = ThreadSummary::Generating;
2060
2061 self.pending_summary = cx.spawn(async move |this, cx| {
2062 let result = async {
2063 let mut messages = model.model.stream_completion(request, cx).await?;
2064
2065 let mut new_summary = String::new();
2066 while let Some(event) = messages.next().await {
2067 let Ok(event) = event else {
2068 continue;
2069 };
2070 let text = match event {
2071 LanguageModelCompletionEvent::Text(text) => text,
2072 LanguageModelCompletionEvent::StatusUpdate(
2073 CompletionRequestStatus::UsageUpdated { amount, limit },
2074 ) => {
2075 this.update(cx, |thread, cx| {
2076 thread.update_model_request_usage(amount as u32, limit, cx);
2077 })?;
2078 continue;
2079 }
2080 _ => continue,
2081 };
2082
2083 let mut lines = text.lines();
2084 new_summary.extend(lines.next());
2085
2086 // Stop if the LLM generated multiple lines.
2087 if lines.next().is_some() {
2088 break;
2089 }
2090 }
2091
2092 anyhow::Ok(new_summary)
2093 }
2094 .await;
2095
2096 this.update(cx, |this, cx| {
2097 match result {
2098 Ok(new_summary) => {
2099 if new_summary.is_empty() {
2100 this.summary = ThreadSummary::Error;
2101 } else {
2102 this.summary = ThreadSummary::Ready(new_summary.into());
2103 }
2104 }
2105 Err(err) => {
2106 this.summary = ThreadSummary::Error;
2107 log::error!("Failed to generate thread summary: {}", err);
2108 }
2109 }
2110 cx.emit(ThreadEvent::SummaryGenerated);
2111 })
2112 .log_err()?;
2113
2114 Some(())
2115 });
2116 }
2117
    /// Maps a completion error to an automatic retry strategy, or `None`
    /// when retrying cannot help (e.g. authentication failures or an
    /// oversized prompt).
    fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
        use LanguageModelCompletionError::*;

        // General strategy here:
        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
        match error {
            HttpResponseError {
                status_code: StatusCode::TOO_MANY_REQUESTS,
                ..
            } => Some(RetryStrategy::ExponentialBackoff {
                initial_delay: BASE_RETRY_DELAY,
                max_attempts: MAX_RETRY_ATTEMPTS,
            }),
            // Honor the server-provided `retry_after` hint when present.
            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
                Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    max_attempts: MAX_RETRY_ATTEMPTS,
                })
            }
            UpstreamProviderError {
                status,
                retry_after,
                ..
            } => match *status {
                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
                    Some(RetryStrategy::Fixed {
                        delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                        max_attempts: MAX_RETRY_ATTEMPTS,
                    })
                }
                StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    // Internal Server Error could be anything, retry up to 3 times.
                    max_attempts: 3,
                }),
                status => {
                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
                    // but we frequently get them in practice. See https://http.dev/529
                    if status.as_u16() == 529 {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: MAX_RETRY_ATTEMPTS,
                        })
                    } else {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: 2,
                        })
                    }
                }
            },
            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            ApiReadResponseError { .. }
            | HttpSend { .. }
            | DeserializeResponse { .. }
            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Retrying these errors definitely shouldn't help.
            HttpResponseError {
                status_code:
                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
                ..
            }
            | AuthenticationError { .. }
            | PermissionError { .. }
            | NoApiKey { .. }
            | ApiEndpointNotFound { .. }
            | PromptTooLarge { .. } => None,
            // These errors might be transient, so retry them once.
            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 1,
            }),
            // Retry all other 4xx and 5xx errors up to 3 times.
            HttpResponseError { status_code, .. }
                if status_code.is_client_error() || status_code.is_server_error() =>
            {
                Some(RetryStrategy::Fixed {
                    delay: BASE_RETRY_DELAY,
                    max_attempts: 3,
                })
            }
            Other(err)
                if err.is::<PaymentRequiredError>()
                    || err.is::<ModelRequestLimitReachedError>() =>
            {
                // Retrying won't help for Payment Required or Model Request Limit errors (where
                // the user must upgrade to usage-based billing to get more requests, or else wait
                // for a significant amount of time for the request limit to reset).
                None
            }
            // Conservatively assume that any other errors are non-retryable
            HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 2,
            }),
        }
    }
2223
    /// Handles an error from a completion request that may be retryable.
    ///
    /// Always records `(model, intent)` so the UI's Retry button can resend the
    /// request later. Automatic retries only happen in Burn Mode; otherwise an
    /// error offering to enable Burn Mode is emitted. Returns `true` when a
    /// retry was scheduled, `false` when the caller should treat the error as
    /// final.
    fn handle_retryable_error_with_delay(
        &mut self,
        error: &LanguageModelCompletionError,
        strategy: Option<RetryStrategy>,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Store context for the Retry button
        self.last_error_context = Some((model.clone(), intent));

        // Only auto-retry if Burn Mode is enabled
        if self.completion_mode != CompletionMode::Burn {
            // Show error with retry options
            cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
                message: format!(
                    "{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.",
                    error
                )
                .into(),
                can_enable_burn_mode: true,
            }));
            return false;
        }

        // An explicit strategy from the caller wins; otherwise derive one from
        // the error kind. Non-retryable errors bail out here.
        let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
            return false;
        };

        let max_attempts = match &strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
        };

        // Reuse existing retry state so the attempt counter survives across
        // consecutive failures of the same request.
        let retry_state = self.retry_state.get_or_insert(RetryState {
            attempt: 0,
            max_attempts,
            intent,
        });

        retry_state.attempt += 1;
        let attempt = retry_state.attempt;
        let max_attempts = retry_state.max_attempts;
        let intent = retry_state.intent;

        if attempt <= max_attempts {
            let delay = match &strategy {
                RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
                    // Doubles the delay for every attempt after the first.
                    let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
                    Duration::from_secs(delay_secs)
                }
                RetryStrategy::Fixed { delay, .. } => *delay,
            };

            // Add a transient message to inform the user
            let delay_secs = delay.as_secs();
            let retry_message = if max_attempts == 1 {
                format!("{error}. Retrying in {delay_secs} seconds...")
            } else {
                format!(
                    "{error}. Retrying (attempt {attempt} of {max_attempts}) \
                    in {delay_secs} seconds..."
                )
            };
            log::warn!(
                "Retrying completion request (attempt {attempt} of {max_attempts}) \
                in {delay_secs} seconds: {error:?}",
            );

            // Add a UI-only message instead of a regular message
            let id = self.next_message_id.post_inc();
            self.messages.push(Message {
                id,
                role: Role::System,
                segments: vec![MessageSegment::Text(retry_message)],
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
                is_hidden: false,
                ui_only: true,
            });
            cx.emit(ThreadEvent::MessageAdded(id));

            // Schedule the retry
            let thread_handle = cx.entity().downgrade();

            cx.spawn(async move |_thread, cx| {
                cx.background_executor().timer(delay).await;

                thread_handle
                    .update(cx, |thread, cx| {
                        // Retry the completion
                        thread.send_to_model(model, intent, window, cx);
                    })
                    .log_err();
            })
            .detach();

            true
        } else {
            // Max retries exceeded
            self.retry_state = None;

            // Stop generating since we're giving up on retrying.
            self.pending_completions.clear();

            // Show error alongside a Retry button, but no
            // Enable Burn Mode button (since it's already enabled)
            cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
                message: format!("Failed after retrying: {}", error).into(),
                can_enable_burn_mode: false,
            }));

            false
        }
    }
2340
    /// Kicks off generation of a detailed thread summary unless one is already
    /// generating (or generated) for the current last message.
    ///
    /// Requires an authenticated thread-summary model. The result is published
    /// through `detailed_summary_tx`; on success, the thread is saved via
    /// `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            SUMMARIZE_THREAD_DETAILED_PROMPT.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, cx);
            // If the request fails outright, reset state so a later call can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate streamed text chunks; chunk errors are logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade()
                && let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                })
            {
                save_task.await.log_err();
            }

            Some(())
        });
    }
2428
    /// Waits until the detailed summary is no longer generating, returning the
    /// generated summary, or the plain thread text when no summary exists.
    /// Returns `None` if the thread entity is dropped or the channel closes.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still in flight; wait for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
2446
2447 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2448 self.detailed_summary_rx
2449 .borrow()
2450 .text()
2451 .unwrap_or_else(|| self.text().into())
2452 }
2453
2454 pub fn is_generating_detailed_summary(&self) -> bool {
2455 matches!(
2456 &*self.detailed_summary_rx.borrow(),
2457 DetailedSummaryState::Generating { .. }
2458 )
2459 }
2460
2461 pub fn use_pending_tools(
2462 &mut self,
2463 window: Option<AnyWindowHandle>,
2464 model: Arc<dyn LanguageModel>,
2465 cx: &mut Context<Self>,
2466 ) -> Vec<PendingToolUse> {
2467 let request =
2468 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2469 let pending_tool_uses = self
2470 .tool_use
2471 .pending_tool_uses()
2472 .into_iter()
2473 .filter(|tool_use| tool_use.status.is_idle())
2474 .cloned()
2475 .collect::<Vec<_>>();
2476
2477 for tool_use in pending_tool_uses.iter() {
2478 self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2479 }
2480
2481 pending_tool_uses
2482 }
2483
    /// Dispatches a single pending tool use.
    ///
    /// Unknown or profile-disabled tools are routed to
    /// [`Self::handle_hallucinated_tool_use`]. Tools that require confirmation
    /// (and aren't globally allowed by settings) are queued and a
    /// `ToolConfirmationNeeded` event is emitted; everything else runs now.
    fn use_pending_tool(
        &mut self,
        tool_use: PendingToolUse,
        request: Arc<LanguageModelRequest>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        // The model may hallucinate a tool that doesn't exist.
        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
        };

        // Tools disabled by the current profile are treated like unknown tools.
        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
        }

        if tool.needs_confirmation(&tool_use.input, &self.project, cx)
            && !AgentSettings::get_global(cx).always_allow_tool_actions
        {
            self.tool_use.confirm_tool_use(
                tool_use.id,
                tool_use.ui_text,
                tool_use.input,
                request,
                tool,
            );
            cx.emit(ThreadEvent::ToolConfirmationNeeded);
        } else {
            self.run_tool(
                tool_use.id,
                tool_use.ui_text,
                tool_use.input,
                request,
                tool,
                model,
                window,
                cx,
            );
        }
    }
2524
    /// Responds to a tool call for a tool that doesn't exist or is disabled.
    ///
    /// Inserts an error tool result listing the currently available tools (so
    /// the model can self-correct on the next turn), emits
    /// [`ThreadEvent::MissingToolUse`], and finalizes the tool use.
    pub fn handle_hallucinated_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        hallucinated_tool_name: Arc<str>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let available_tools = self.profile.enabled_tools(cx);

        // One "- name: description" line per enabled tool.
        let tool_list = available_tools
            .iter()
            .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
            .collect::<Vec<_>>()
            .join("\n");

        let error_message = format!(
            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
            hallucinated_tool_name, tool_list
        );

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            hallucinated_tool_name,
            Err(anyhow!("Missing tool call: {error_message}")),
            self.configured_model.as_ref(),
            self.completion_mode,
        );

        cx.emit(ThreadEvent::MissingToolUse {
            tool_use_id: tool_use_id.clone(),
            ui_text: error_message.into(),
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
2560
    /// Handles a tool call whose input JSON failed to parse.
    ///
    /// Records the parse error as the tool's result (so the model sees it),
    /// emits [`ThreadEvent::InvalidToolInput`], and finalizes the tool use.
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
            self.completion_mode,
        );
        // Prefer the pending tool use's label; fall back to a generic one if
        // the tool use was somehow never registered as pending.
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
2596
    /// Runs a tool use: spawns the tool's task and marks the tool use as
    /// running under the given UI label.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2613
    /// Starts the given tool and returns a task that records its output.
    ///
    /// Any card produced by the tool is stored immediately so it can be
    /// rendered while the tool is still running; the spawned task awaits the
    /// tool's output, inserts it, and triggers the usual tool-finished
    /// handling.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; in that
                // case the output is silently discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                            thread.completion_mode,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2661
    /// Bookkeeping after a single tool use completes or is canceled.
    ///
    /// When all tools have finished and this one wasn't canceled, the
    /// accumulated tool results are automatically sent back to the configured
    /// model. Always emits [`ThreadEvent::ToolFinished`].
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished()
            && let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref()
            && !canceled
        {
            self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
2682
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Also clears any retry state and cancels every pending tool use,
    /// running the tool-finished bookkeeping for each with `canceled = true`.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // A scheduled retry also counts as an in-flight completion.
        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();

        self.retry_state = None;

        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }
2721
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits
    /// [`ThreadEvent::CancelEditing`]; no thread state is modified here.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2729
    /// Returns the feedback previously recorded for the given message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2733
    /// Records thumbs-up/down feedback for a message and reports it via
    /// telemetry.
    ///
    /// No-ops (returning an immediately-ready `Ok`) when the feedback is
    /// unchanged. The telemetry payload includes the message content, a
    /// serialized copy of the thread, and a project snapshot, all gathered on
    /// a background task; the returned task resolves once the events are
    /// flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Kick off the async pieces of the payload before spawning.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|(name, _)| name.clone().into())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_message_content())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure degrades to a null payload rather than
            // dropping the telemetry event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2790
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Collects one snapshot future per visible worktree, joins them on a
    /// spawned task, and stamps the result with the current time.
    ///
    /// NOTE(review): the `ProjectSnapshot` built here only contains worktree
    /// snapshots and a timestamp — confirm the "unsaved buffers" claim above.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, _| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                timestamp: Utc::now(),
            })
        })
    }
2812
2813 fn worktree_snapshot(
2814 worktree: Entity<project::Worktree>,
2815 git_store: Entity<GitStore>,
2816 cx: &App,
2817 ) -> Task<WorktreeSnapshot> {
2818 cx.spawn(async move |cx| {
2819 // Get worktree path and snapshot
2820 let worktree_info = cx.update(|app_cx| {
2821 let worktree = worktree.read(app_cx);
2822 let path = worktree.abs_path().to_string_lossy().into_owned();
2823 let snapshot = worktree.snapshot();
2824 (path, snapshot)
2825 });
2826
2827 let Ok((worktree_path, _snapshot)) = worktree_info else {
2828 return WorktreeSnapshot {
2829 worktree_path: String::new(),
2830 git_state: None,
2831 };
2832 };
2833
2834 let git_state = git_store
2835 .update(cx, |git_store, cx| {
2836 git_store
2837 .repositories()
2838 .values()
2839 .find(|repo| {
2840 repo.read(cx)
2841 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2842 .is_some()
2843 })
2844 .cloned()
2845 })
2846 .ok()
2847 .flatten()
2848 .map(|repo| {
2849 repo.update(cx, |repo, _| {
2850 let current_branch =
2851 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2852 repo.send_job(None, |state, _| async move {
2853 let RepositoryState::Local { backend, .. } = state else {
2854 return GitState {
2855 remote_url: None,
2856 head_sha: None,
2857 current_branch,
2858 diff: None,
2859 };
2860 };
2861
2862 let remote_url = backend.remote_url("origin");
2863 let head_sha = backend.head_sha().await;
2864 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2865
2866 GitState {
2867 remote_url,
2868 head_sha,
2869 current_branch,
2870 diff,
2871 }
2872 })
2873 })
2874 });
2875
2876 let git_state = match git_state {
2877 Some(git_state) => match git_state.ok() {
2878 Some(git_state) => git_state.await.ok(),
2879 None => None,
2880 },
2881 None => None,
2882 };
2883
2884 WorktreeSnapshot {
2885 worktree_path,
2886 git_state,
2887 }
2888 })
2889 }
2890
2891 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2892 let mut markdown = Vec::new();
2893
2894 let summary = self.summary().or_default();
2895 writeln!(markdown, "# {summary}\n")?;
2896
2897 for message in self.messages() {
2898 writeln!(
2899 markdown,
2900 "## {role}\n",
2901 role = match message.role {
2902 Role::User => "User",
2903 Role::Assistant => "Agent",
2904 Role::System => "System",
2905 }
2906 )?;
2907
2908 if !message.loaded_context.text.is_empty() {
2909 writeln!(markdown, "{}", message.loaded_context.text)?;
2910 }
2911
2912 if !message.loaded_context.images.is_empty() {
2913 writeln!(
2914 markdown,
2915 "\n{} images attached as context.\n",
2916 message.loaded_context.images.len()
2917 )?;
2918 }
2919
2920 for segment in &message.segments {
2921 match segment {
2922 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2923 MessageSegment::Thinking { text, .. } => {
2924 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2925 }
2926 MessageSegment::RedactedThinking(_) => {}
2927 }
2928 }
2929
2930 for tool_use in self.tool_uses_for_message(message.id, cx) {
2931 writeln!(
2932 markdown,
2933 "**Use Tool: {} ({})**",
2934 tool_use.name, tool_use.id
2935 )?;
2936 writeln!(markdown, "```json")?;
2937 writeln!(
2938 markdown,
2939 "{}",
2940 serde_json::to_string_pretty(&tool_use.input)?
2941 )?;
2942 writeln!(markdown, "```")?;
2943 }
2944
2945 for tool_result in self.tool_results_for_message(message.id) {
2946 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2947 if tool_result.is_error {
2948 write!(markdown, " (Error)")?;
2949 }
2950
2951 writeln!(markdown, "**\n")?;
2952 match &tool_result.content {
2953 LanguageModelToolResultContent::Text(text) => {
2954 writeln!(markdown, "{text}")?;
2955 }
2956 LanguageModelToolResultContent::Image(image) => {
2957 writeln!(markdown, "", image.source)?;
2958 }
2959 }
2960
2961 if let Some(output) = tool_result.output.as_ref() {
2962 writeln!(
2963 markdown,
2964 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2965 serde_json::to_string_pretty(output)?
2966 )?;
2967 }
2968 }
2969 }
2970
2971 Ok(String::from_utf8_lossy(&markdown).to_string())
2972 }
2973
    /// Delegates to the action log to keep (accept) edits within the given
    /// range of the buffer.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2984
    /// Delegates to the action log to keep (accept) all outstanding edits.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2989
    /// Delegates to the action log to reject edits in the given buffer
    /// ranges; the returned task resolves when the rejection completes.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
3000
    /// The action log tracking this thread's buffer edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
3004
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
3008
    /// Token usage accumulated over this thread's lifetime.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3012
3013 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3014 let Some(model) = self.configured_model.as_ref() else {
3015 return TotalTokenUsage::default();
3016 };
3017
3018 let max = model
3019 .model
3020 .max_token_count_for_mode(self.completion_mode().into());
3021
3022 let index = self
3023 .messages
3024 .iter()
3025 .position(|msg| msg.id == message_id)
3026 .unwrap_or(0);
3027
3028 if index == 0 {
3029 return TotalTokenUsage { total: 0, max };
3030 }
3031
3032 let token_usage = &self
3033 .request_token_usage
3034 .get(index - 1)
3035 .cloned()
3036 .unwrap_or_default();
3037
3038 TotalTokenUsage {
3039 total: token_usage.total_tokens(),
3040 max,
3041 }
3042 }
3043
    /// Total token usage for the thread against the configured model's limit.
    ///
    /// Returns `None` when no model is configured. If the context window was
    /// previously exceeded with the same model, the recorded overflow count is
    /// reported instead of the last request's usage.
    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
        let model = self.configured_model.as_ref()?;

        let max = model
            .model
            .max_token_count_for_mode(self.completion_mode().into());

        if let Some(exceeded_error) = &self.exceeded_window_error
            && model.model.id() == exceeded_error.model_id
        {
            return Some(TotalTokenUsage {
                total: exceeded_error.token_count,
                max,
            });
        }

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens();

        Some(TotalTokenUsage { total, max })
    }
3067
    /// Token usage recorded at the last message, falling back to the most
    /// recent recorded entry when the usage vec is shorter than the messages.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
3074
    /// Overwrites the token usage entry for the last message, growing the
    /// per-message usage vec (padded with the previous value) if needed.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
3084
    /// Pushes the updated model-request usage (amount and limit) into the
    /// project's user store.
    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project
            .read(cx)
            .user_store()
            .update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            });
    }
3099
3100 pub fn deny_tool_use(
3101 &mut self,
3102 tool_use_id: LanguageModelToolUseId,
3103 tool_name: Arc<str>,
3104 window: Option<AnyWindowHandle>,
3105 cx: &mut Context<Self>,
3106 ) {
3107 let err = Err(anyhow::anyhow!(
3108 "Permission to run tool action denied by user"
3109 ));
3110
3111 self.tool_use.insert_tool_output(
3112 tool_use_id.clone(),
3113 tool_name,
3114 err,
3115 self.configured_model.as_ref(),
3116 self.completion_mode,
3117 );
3118 self.tool_finished(tool_use_id, None, true, window, cx);
3119 }
3120}
3121
/// User-facing errors surfaced by a thread via [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before more requests will be served.
    #[error("Payment required")]
    PaymentRequired,
    /// The plan's model request limit has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and detail text for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
    /// An error that may succeed if the request is retried.
    #[error("Retryable error: {message}")]
    RetryableError {
        message: SharedString,
        /// Whether the UI should offer to enable Burn Mode
        /// (false when Burn Mode is already enabled).
        can_enable_burn_mode: bool,
    },
}
3139
/// Events emitted by a thread for the UI and other listeners.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of assistant text streamed for the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant thinking text streamed for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use streamed in from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model requested a tool that doesn't exist or is disabled.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished (or failed) with the given stop reason.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
}
3184
// Lets `Thread` entities emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
3186
/// A completion request that is currently in flight for a thread.
struct PendingCompletion {
    // Identifies this completion among the thread's pending completions.
    id: usize,
    // Where this completion sits relative to the send queue.
    queue_state: QueueState,
    // Keeps the underlying task alive while the completion is pending.
    _task: Task<()>,
}
3192
3193#[cfg(test)]
3194mod tests {
3195 use super::*;
3196 use crate::{
3197 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3198 };
3199
3200 // Test-specific constants
3201 const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3202 use agent_settings::{AgentProfileId, AgentSettings};
3203 use assistant_tool::ToolRegistry;
3204 use assistant_tools;
3205 use fs::Fs;
3206 use futures::StreamExt;
3207 use futures::future::BoxFuture;
3208 use futures::stream::BoxStream;
3209 use gpui::TestAppContext;
3210 use http_client;
3211 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3212 use language_model::{
3213 LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3214 LanguageModelProviderName, LanguageModelToolChoice,
3215 };
3216 use parking_lot::Mutex;
3217 use project::{FakeFs, Project};
3218 use prompt_store::PromptBuilder;
3219 use serde_json::json;
3220 use settings::{LanguageModelParameters, Settings, SettingsStore};
3221 use std::sync::Arc;
3222 use std::time::Duration;
3223 use theme::ThemeSettings;
3224 use util::path;
3225 use workspace::Workspace;
3226
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        // Verifies that context attached to a user message appears both in the
        // stored message and in the completion request sent to the model.
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Request contains the system message plus the single user message,
        // with the context text prepended to the user's prompt.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3304
    // Verifies that each user message only attaches contexts that were not
    // already attached to an earlier message in the thread, and that
    // `new_context_for_thread` can be re-anchored to an earlier message id.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1. With `None` as the anchor,
        // `new_context_for_thread` considers the whole thread, which is empty,
        // so file1 is the single new context.
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 user messages (index 0 is the
        // leading non-user message the thread always prepends).
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Anchoring at `Some(message2_id)` asks for contexts that are new
        // relative to the messages before message 2: file1 was attached before
        // it, so file2, file3, and the freshly-added file4 all count as new.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        // Removing a context from the store shrinks the "new" set accordingly.
        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3461
3462 #[gpui::test]
3463 async fn test_message_without_files(cx: &mut TestAppContext) {
3464 let fs = init_test_settings(cx);
3465
3466 let project = create_test_project(
3467 &fs,
3468 cx,
3469 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3470 )
3471 .await;
3472
3473 let (_, _thread_store, thread, _context_store, model) =
3474 setup_test_environment(cx, project.clone()).await;
3475
3476 // Insert user message without any context (empty context vector)
3477 let message_id = thread.update(cx, |thread, cx| {
3478 thread.insert_user_message(
3479 "What is the best way to learn Rust?",
3480 ContextLoadResult::default(),
3481 None,
3482 Vec::new(),
3483 cx,
3484 )
3485 });
3486
3487 // Check content and context in message object
3488 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3489
3490 // Context should be empty when no files are included
3491 assert_eq!(message.role, Role::User);
3492 assert_eq!(message.segments.len(), 1);
3493 assert_eq!(
3494 message.segments[0],
3495 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3496 );
3497 assert_eq!(message.loaded_context.text, "");
3498
3499 // Check message in request
3500 let request = thread.update(cx, |thread, cx| {
3501 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3502 });
3503
3504 assert_eq!(request.messages.len(), 2);
3505 assert_eq!(
3506 request.messages[1].string_contents(),
3507 "What is the best way to learn Rust?"
3508 );
3509
3510 // Add second message, also without context
3511 let message2_id = thread.update(cx, |thread, cx| {
3512 thread.insert_user_message(
3513 "Are there any good books?",
3514 ContextLoadResult::default(),
3515 None,
3516 Vec::new(),
3517 cx,
3518 )
3519 });
3520
3521 let message2 =
3522 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3523 assert_eq!(message2.loaded_context.text, "");
3524
3525 // Check that both messages appear in the request
3526 let request = thread.update(cx, |thread, cx| {
3527 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3528 });
3529
3530 assert_eq!(request.messages.len(), 3);
3531 assert_eq!(
3532 request.messages[1].string_contents(),
3533 "What is the best way to learn Rust?"
3534 );
3535 assert_eq!(
3536 request.messages[2].string_contents(),
3537 "Are there any good books?"
3538 );
3539 }
3540
    // Verifies that a `project_notifications` tool result is injected exactly
    // once after a tracked context buffer is edited, and is not duplicated by
    // later flushes.
    #[gpui::test]
    #[ignore] // turn this test on when project_notifications tool is re-enabled
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context. This will be a tracked buffer
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message and assistant response
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });
        cx.run_until_parked();

        // We shouldn't have a stale buffer notification yet
        let notifications = thread.read_with(cx, |thread, _| {
            find_tool_uses(thread, "project_notifications")
        });
        assert!(
            notifications.is_empty(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Modify the buffer so the copy in context becomes stale.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });
        cx.run_until_parked();

        // Flushing notifications should now surface the stale buffer warning.
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        cx.run_until_parked();

        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        // Exactly one notification is expected; its content is plain text.
        let [notification] = notifications.as_slice() else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        assert!(notification_content.contains("These files have changed since the last read:"));
        assert!(notification_content.contains("code.rs"));

        // Insert another user message and flush notifications again
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Can you tell me more?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        cx.run_until_parked();

        // There should be no new notifications (we already flushed one)
        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        assert_eq!(
            notifications.len(),
            1,
            "Should still have only one notification after second flush - no duplicates"
        );
    }
3656
3657 fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3658 thread
3659 .messages()
3660 .flat_map(|message| {
3661 thread
3662 .tool_results_for_message(message.id)
3663 .into_iter()
3664 .filter(|result| result.tool_name == tool_name.into())
3665 .cloned()
3666 .collect::<Vec<_>>()
3667 })
3668 .collect()
3669 }
3670
3671 #[gpui::test]
3672 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3673 let fs = init_test_settings(cx);
3674
3675 let project = create_test_project(
3676 &fs,
3677 cx,
3678 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3679 )
3680 .await;
3681
3682 let (_workspace, thread_store, thread, _context_store, _model) =
3683 setup_test_environment(cx, project.clone()).await;
3684
3685 // Check that we are starting with the default profile
3686 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3687 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3688 assert_eq!(
3689 profile,
3690 AgentProfile::new(AgentProfileId::default(), tool_set)
3691 );
3692 }
3693
    // Round-trips a thread through serialize/deserialize and checks that the
    // agent profile survives: it is stored as a profile id and rebuilt against
    // the tool set on load.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        // Only the profile *id* is persisted, not the full profile.
        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form, reusing the original
        // thread's project, tools, prompt builder, and project context.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        // The deserialized thread reconstructs the full profile from the id.
        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3737
3738 #[gpui::test]
3739 async fn test_temperature_setting(cx: &mut TestAppContext) {
3740 let fs = init_test_settings(cx);
3741
3742 let project = create_test_project(
3743 &fs,
3744 cx,
3745 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3746 )
3747 .await;
3748
3749 let (_workspace, _thread_store, thread, _context_store, model) =
3750 setup_test_environment(cx, project.clone()).await;
3751
3752 // Both model and provider
3753 cx.update(|cx| {
3754 AgentSettings::override_global(
3755 AgentSettings {
3756 model_parameters: vec![LanguageModelParameters {
3757 provider: Some(model.provider_id().0.to_string().into()),
3758 model: Some(model.id().0),
3759 temperature: Some(0.66),
3760 }],
3761 ..AgentSettings::get_global(cx).clone()
3762 },
3763 cx,
3764 );
3765 });
3766
3767 let request = thread.update(cx, |thread, cx| {
3768 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3769 });
3770 assert_eq!(request.temperature, Some(0.66));
3771
3772 // Only model
3773 cx.update(|cx| {
3774 AgentSettings::override_global(
3775 AgentSettings {
3776 model_parameters: vec![LanguageModelParameters {
3777 provider: None,
3778 model: Some(model.id().0),
3779 temperature: Some(0.66),
3780 }],
3781 ..AgentSettings::get_global(cx).clone()
3782 },
3783 cx,
3784 );
3785 });
3786
3787 let request = thread.update(cx, |thread, cx| {
3788 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3789 });
3790 assert_eq!(request.temperature, Some(0.66));
3791
3792 // Only provider
3793 cx.update(|cx| {
3794 AgentSettings::override_global(
3795 AgentSettings {
3796 model_parameters: vec![LanguageModelParameters {
3797 provider: Some(model.provider_id().0.to_string().into()),
3798 model: None,
3799 temperature: Some(0.66),
3800 }],
3801 ..AgentSettings::get_global(cx).clone()
3802 },
3803 cx,
3804 );
3805 });
3806
3807 let request = thread.update(cx, |thread, cx| {
3808 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3809 });
3810 assert_eq!(request.temperature, Some(0.66));
3811
3812 // Same model name, different provider
3813 cx.update(|cx| {
3814 AgentSettings::override_global(
3815 AgentSettings {
3816 model_parameters: vec![LanguageModelParameters {
3817 provider: Some("anthropic".into()),
3818 model: Some(model.id().0),
3819 temperature: Some(0.66),
3820 }],
3821 ..AgentSettings::get_global(cx).clone()
3822 },
3823 cx,
3824 );
3825 });
3826
3827 let request = thread.update(cx, |thread, cx| {
3828 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3829 });
3830 assert_eq!(request.temperature, None);
3831 }
3832
    // Exercises the thread-summary state machine: Pending -> Generating ->
    // Ready. Manual `set_summary` is rejected while Pending or Generating,
    // allowed once Ready, and an empty manual summary falls back to the
    // default title.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        // Drive the fake model so the assistant reply completes.
        let fake_model = model.as_fake();
        simulate_successful_response(fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.send_last_completion_stream_text_chunk("Brief");
        fake_model.send_last_completion_stream_text_chunk(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set (chunks are concatenated)
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3917
3918 #[gpui::test]
3919 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3920 let fs = init_test_settings(cx);
3921
3922 let project = create_test_project(&fs, cx, json!({})).await;
3923
3924 let (_, _thread_store, thread, _context_store, model) =
3925 setup_test_environment(cx, project.clone()).await;
3926
3927 test_summarize_error(&model, &thread, cx);
3928
3929 // Now we should be able to set a summary
3930 thread.update(cx, |thread, cx| {
3931 thread.set_summary("Brief Intro", cx);
3932 });
3933
3934 thread.read_with(cx, |thread, _| {
3935 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3936 assert_eq!(thread.summary().or_default(), "Brief Intro");
3937 });
3938 }
3939
    // After a summarization failure, subsequent messages do not re-trigger
    // summarization automatically, but calling `summarize` explicitly retries
    // and can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the summary into the Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Complete the retried summarization successfully this time.
        cx.run_until_parked();
        fake_model.send_last_completion_stream_text_chunk("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3990
    // Kinds of failure `ErrorInjector` can surface from its completion stream,
    // used to exercise the thread's retry logic.
    enum TestError {
        // Provider reports it is temporarily overloaded.
        Overloaded,
        // Provider reports an internal server error.
        InternalServerError,
    }
3996
    // A `LanguageModel` that wraps a `FakeLanguageModel` but fails every
    // `stream_completion` call with the configured `error_type`.
    struct ErrorInjector {
        inner: Arc<FakeLanguageModel>,
        error_type: TestError,
    }
4001
4002 impl ErrorInjector {
4003 fn new(error_type: TestError) -> Self {
4004 Self {
4005 inner: Arc::new(FakeLanguageModel::default()),
4006 error_type,
4007 }
4008 }
4009 }
4010
    // Delegates every identity/capability query to the wrapped fake model;
    // only `stream_completion` is overridden to always fail.
    impl LanguageModel for ErrorInjector {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        // Returns a stream whose only item is the configured error, so every
        // completion attempt fails immediately after the stream opens.
        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
                    provider: self.provider_name(),
                    retry_after: None,
                },
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError {
                        provider: self.provider_name(),
                        message: "I'm a teapot orbiting the sun".to_string(),
                    }
                }
            };
            async move {
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        // Exposes the wrapped fake model so tests can drive it directly.
        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
4093
    // A ServerOverloaded error in Burn Mode should schedule a retry, record
    // retry state (attempt 1 of MAX_RETRY_ATTEMPTS), and add exactly one
    // ui-only "Retrying..." system message.
    #[gpui::test]
    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // The failed attempt should have recorded retry state.
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
            );
        });

        // Check that a retry message was added
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("overloaded")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message"
            );
        });

        // Exactly one ui-only "Retrying in N seconds" notice should exist.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4171
    // An ApiInternalServerError in Burn Mode should schedule a retry with a
    // max of 3 attempts, and the retry message should name the provider.
    #[gpui::test]
    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // Create model that returns internal server error
        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Check retry state on thread
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, 3,
                "Should have correct max attempts"
            );
        });

        // Check that a retry message was added with provider name
        // ("Fake" is the wrapped FakeLanguageModel's provider).
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("internal")
                                    && text.contains("Fake")
                                    && text.contains("Retrying")
                                    && text.contains("attempt 1 of 3")
                                    && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message with provider name"
            );
        });

        // Count retry messages
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4253
4254 #[gpui::test]
4255 async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4256 let fs = init_test_settings(cx);
4257
4258 let project = create_test_project(&fs, cx, json!({})).await;
4259 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4260
4261 // Enable Burn Mode to allow retries
4262 thread.update(cx, |thread, _| {
4263 thread.set_completion_mode(CompletionMode::Burn);
4264 });
4265
4266 // Create model that returns internal server error
4267 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4268
4269 // Insert a user message
4270 thread.update(cx, |thread, cx| {
4271 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4272 });
4273
4274 // Track retry events and completion count
4275 // Track completion events
4276 let completion_count = Arc::new(Mutex::new(0));
4277 let completion_count_clone = completion_count.clone();
4278
4279 let _subscription = thread.update(cx, |_, cx| {
4280 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4281 if let ThreadEvent::NewRequest = event {
4282 *completion_count_clone.lock() += 1;
4283 }
4284 })
4285 });
4286
4287 // First attempt
4288 thread.update(cx, |thread, cx| {
4289 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4290 });
4291 cx.run_until_parked();
4292
4293 // Should have scheduled first retry - count retry messages
4294 let retry_count = thread.update(cx, |thread, _| {
4295 thread
4296 .messages
4297 .iter()
4298 .filter(|m| {
4299 m.ui_only
4300 && m.segments.iter().any(|s| {
4301 if let MessageSegment::Text(text) = s {
4302 text.contains("Retrying") && text.contains("seconds")
4303 } else {
4304 false
4305 }
4306 })
4307 })
4308 .count()
4309 });
4310 assert_eq!(retry_count, 1, "Should have scheduled first retry");
4311
4312 // Check retry state
4313 thread.read_with(cx, |thread, _| {
4314 assert!(thread.retry_state.is_some(), "Should have retry state");
4315 let retry_state = thread.retry_state.as_ref().unwrap();
4316 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4317 assert_eq!(
4318 retry_state.max_attempts, 3,
4319 "Internal server errors should retry up to 3 times"
4320 );
4321 });
4322
4323 // Advance clock for first retry
4324 cx.executor().advance_clock(BASE_RETRY_DELAY);
4325 cx.run_until_parked();
4326
4327 // Advance clock for second retry
4328 cx.executor().advance_clock(BASE_RETRY_DELAY);
4329 cx.run_until_parked();
4330
4331 // Advance clock for third retry
4332 cx.executor().advance_clock(BASE_RETRY_DELAY);
4333 cx.run_until_parked();
4334
4335 // Should have completed all retries - count retry messages
4336 let retry_count = thread.update(cx, |thread, _| {
4337 thread
4338 .messages
4339 .iter()
4340 .filter(|m| {
4341 m.ui_only
4342 && m.segments.iter().any(|s| {
4343 if let MessageSegment::Text(text) = s {
4344 text.contains("Retrying") && text.contains("seconds")
4345 } else {
4346 false
4347 }
4348 })
4349 })
4350 .count()
4351 });
4352 assert_eq!(
4353 retry_count, 3,
4354 "Should have 3 retries for internal server errors"
4355 );
4356
4357 // For internal server errors, we retry 3 times and then give up
4358 // Check that retry_state is cleared after all retries
4359 thread.read_with(cx, |thread, _| {
4360 assert!(
4361 thread.retry_state.is_none(),
4362 "Retry state should be cleared after all retries"
4363 );
4364 });
4365
4366 // Verify total attempts (1 initial + 3 retries)
4367 assert_eq!(
4368 *completion_count.lock(),
4369 4,
4370 "Should have attempted once plus 3 retries"
4371 );
4372 }
4373
4374 #[gpui::test]
4375 async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4376 let fs = init_test_settings(cx);
4377
4378 let project = create_test_project(&fs, cx, json!({})).await;
4379 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4380
4381 // Enable Burn Mode to allow retries
4382 thread.update(cx, |thread, _| {
4383 thread.set_completion_mode(CompletionMode::Burn);
4384 });
4385
4386 // Create model that returns overloaded error
4387 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4388
4389 // Insert a user message
4390 thread.update(cx, |thread, cx| {
4391 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4392 });
4393
4394 // Track events
4395 let stopped_with_error = Arc::new(Mutex::new(false));
4396 let stopped_with_error_clone = stopped_with_error.clone();
4397
4398 let _subscription = thread.update(cx, |_, cx| {
4399 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4400 if let ThreadEvent::Stopped(Err(_)) = event {
4401 *stopped_with_error_clone.lock() = true;
4402 }
4403 })
4404 });
4405
4406 // Start initial completion
4407 thread.update(cx, |thread, cx| {
4408 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4409 });
4410 cx.run_until_parked();
4411
4412 // Advance through all retries
4413 for _ in 0..MAX_RETRY_ATTEMPTS {
4414 cx.executor().advance_clock(BASE_RETRY_DELAY);
4415 cx.run_until_parked();
4416 }
4417
4418 let retry_count = thread.update(cx, |thread, _| {
4419 thread
4420 .messages
4421 .iter()
4422 .filter(|m| {
4423 m.ui_only
4424 && m.segments.iter().any(|s| {
4425 if let MessageSegment::Text(text) = s {
4426 text.contains("Retrying") && text.contains("seconds")
4427 } else {
4428 false
4429 }
4430 })
4431 })
4432 .count()
4433 });
4434
4435 // After max retries, should emit Stopped(Err(...)) event
4436 assert_eq!(
4437 retry_count, MAX_RETRY_ATTEMPTS as usize,
4438 "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
4439 );
4440 assert!(
4441 *stopped_with_error.lock(),
4442 "Should emit Stopped(Err(...)) event after max retries exceeded"
4443 );
4444
4445 // Retry state should be cleared
4446 thread.read_with(cx, |thread, _| {
4447 assert!(
4448 thread.retry_state.is_none(),
4449 "Retry state should be cleared after max retries"
4450 );
4451
4452 // Verify we have the expected number of retry messages
4453 let retry_messages = thread
4454 .messages
4455 .iter()
4456 .filter(|msg| msg.ui_only && msg.role == Role::System)
4457 .count();
4458 assert_eq!(
4459 retry_messages, MAX_RETRY_ATTEMPTS as usize,
4460 "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
4461 );
4462 });
4463 }
4464
    #[gpui::test]
    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // We'll use a wrapper to switch behavior after first failure:
        // the first stream_completion call fails with ServerOverloaded,
        // every later call delegates to the inner FakeLanguageModel so the
        // test can drive the retry's stream by hand.
        struct RetryTestModel {
            inner: Arc<FakeLanguageModel>,
            // Flipped to true on the first stream_completion call.
            failed_once: Arc<Mutex<bool>>,
        }

        // All metadata/token accounting delegates to the inner fake model;
        // only stream_completion has custom behavior.
        impl LanguageModel for RetryTestModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }

            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track message deletions
        // Track when retry completes successfully
        let retry_completed = Arc::new(Mutex::new(false));
        let retry_completed_clone = retry_completed.clone();

        // A StreamedCompletion event signals that the retried request
        // produced output.
        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::StreamedCompletion = event {
                    *retry_completed_clone.lock() = true;
                }
            })
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Get the retry message ID (the first failure should have inserted a
        // ui-only System message announcing the retry).
        let retry_message_id = thread.read_with(cx, |thread, _| {
            thread
                .messages()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .map(|msg| msg.id)
                .expect("Should have a retry message")
        });

        // Wait for retry
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();

        // Stream some successful content
        let fake_model = model.as_fake();
        // After the retry, there should be a new pending completion
        let pending = fake_model.pending_completions();
        assert!(
            !pending.is_empty(),
            "Should have a pending completion after retry"
        );
        fake_model.send_completion_stream_text_chunk(&pending[0], "Success!");
        fake_model.end_completion_stream(&pending[0]);
        cx.run_until_parked();

        // Check that the retry completed successfully
        assert!(
            *retry_completed.lock(),
            "Retry should have completed successfully"
        );

        // Retry message should still exist but be marked as ui_only
        // (it is kept for display, not removed, despite the test's name).
        thread.read_with(cx, |thread, _| {
            let retry_msg = thread
                .message(retry_message_id)
                .expect("Retry message should still exist");
            assert!(retry_msg.ui_only, "Retry message should be ui_only");
            assert_eq!(
                retry_msg.role,
                Role::System,
                "Retry message should have System role"
            );
        });
    }
4637
    #[gpui::test]
    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // Create a model that fails once then succeeds: the first
        // stream_completion returns ServerOverloaded, later calls delegate
        // to the inner FakeLanguageModel.
        struct FailOnceModel {
            inner: Arc<FakeLanguageModel>,
            // Flipped to true on the first stream_completion call.
            failed_once: Arc<Mutex<bool>>,
        }

        // Everything except stream_completion delegates to the inner fake.
        impl LanguageModel for FailOnceModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }
        }

        let fail_once_model = Arc::new(FailOnceModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Test message",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
        });

        // Start completion with fail-once model
        thread.update(cx, |thread, cx| {
            thread.send_to_model(
                fail_once_model.clone(),
                CompletionIntent::UserPrompt,
                None,
                cx,
            );
        });

        cx.run_until_parked();

        // Verify retry state exists after first failure
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_some(),
                "Should have retry state after failure"
            );
        });

        // Wait for retry delay
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();

        // The retry should now use our FailOnceModel which should succeed
        // We need to help the FakeLanguageModel complete the stream
        let inner_fake = fail_once_model.inner.clone();

        // Wait a bit for the retry to start
        cx.run_until_parked();

        // Check for pending completions and complete them
        if let Some(pending) = inner_fake.pending_completions().first() {
            inner_fake.send_completion_stream_text_chunk(pending, "Success!");
            inner_fake.end_completion_stream(pending);
        }
        cx.run_until_parked();

        // A successful completion must clear retry_state and leave a real
        // (non-ui-only) assistant message in the thread.
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after successful completion"
            );

            let has_assistant_message = thread
                .messages
                .iter()
                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
            assert!(
                has_assistant_message,
                "Should have an assistant message after successful retry"
            );
        });
    }
4803
    #[gpui::test]
    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // Create a model that returns rate limit error with retry_after
        // on every attempt; metadata delegates to the inner fake model.
        struct RateLimitModel {
            inner: Arc<FakeLanguageModel>,
        }

        impl LanguageModel for RateLimitModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            // Always yields a RateLimitExceeded error carrying an explicit
            // retry_after hint.
            fn stream_completion(
                &self,
                _request: LanguageModelRequest,
                _cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                let provider = self.provider_name();
                async move {
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::RateLimitExceeded {
                            provider,
                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
                        })
                    });
                    Ok(stream.boxed())
                }
                .boxed()
            }

            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RateLimitModel {
            inner: Arc::new(FakeLanguageModel::default()),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Count the ui-only messages mentioning the rate limit.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("rate limit exceeded")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(retry_count, 1, "Should have scheduled one retry");

        // Rate-limit retries go through the same retry_state machinery as
        // other retryable errors, with the full attempt budget.
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_some(),
                "Rate limit errors should set retry_state"
            );
            if let Some(retry_state) = &thread.retry_state {
                assert_eq!(
                    retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                    "Rate limit errors should use MAX_RETRY_ATTEMPTS"
                );
            }
        });

        // Verify we have one retry message
        thread.read_with(cx, |thread, _| {
            let retry_messages = thread
                .messages
                .iter()
                .filter(|msg| {
                    msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("rate limit exceeded")
                            } else {
                                false
                            }
                        })
                })
                .count();
            assert_eq!(
                retry_messages, 1,
                "Should have one rate limit retry message"
            );
        });

        // Check that the retry message includes the attempt count
        thread.read_with(cx, |thread, _| {
            let retry_message = thread
                .messages
                .iter()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .expect("Should have a retry message");

            // Check that the message contains attempt count since we use retry_state
            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
                assert!(
                    text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)),
                    "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
                );
                assert!(
                    text.contains("Retrying"),
                    "Rate limit retry message should contain retry text"
                );
            }
        });
    }
4988
4989 #[gpui::test]
4990 async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
4991 let fs = init_test_settings(cx);
4992
4993 let project = create_test_project(&fs, cx, json!({})).await;
4994 let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
4995
4996 // Insert a regular user message
4997 thread.update(cx, |thread, cx| {
4998 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4999 });
5000
5001 // Insert a UI-only message (like our retry notifications)
5002 thread.update(cx, |thread, cx| {
5003 let id = thread.next_message_id.post_inc();
5004 thread.messages.push(Message {
5005 id,
5006 role: Role::System,
5007 segments: vec![MessageSegment::Text(
5008 "This is a UI-only message that should not be sent to the model".to_string(),
5009 )],
5010 loaded_context: LoadedContext::default(),
5011 creases: Vec::new(),
5012 is_hidden: true,
5013 ui_only: true,
5014 });
5015 cx.emit(ThreadEvent::MessageAdded(id));
5016 });
5017
5018 // Insert another regular message
5019 thread.update(cx, |thread, cx| {
5020 thread.insert_user_message(
5021 "How are you?",
5022 ContextLoadResult::default(),
5023 None,
5024 vec![],
5025 cx,
5026 );
5027 });
5028
5029 // Generate the completion request
5030 let request = thread.update(cx, |thread, cx| {
5031 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5032 });
5033
5034 // Verify that the request only contains non-UI-only messages
5035 // Should have system prompt + 2 user messages, but not the UI-only message
5036 let user_messages: Vec<_> = request
5037 .messages
5038 .iter()
5039 .filter(|msg| msg.role == Role::User)
5040 .collect();
5041 assert_eq!(
5042 user_messages.len(),
5043 2,
5044 "Should have exactly 2 user messages"
5045 );
5046
5047 // Verify the UI-only content is not present anywhere in the request
5048 let request_text = request
5049 .messages
5050 .iter()
5051 .flat_map(|msg| &msg.content)
5052 .filter_map(|content| match content {
5053 MessageContent::Text(text) => Some(text.as_str()),
5054 _ => None,
5055 })
5056 .collect::<String>();
5057
5058 assert!(
5059 !request_text.contains("UI-only message"),
5060 "UI-only message content should not be in the request"
5061 );
5062
5063 // Verify the thread still has all 3 messages (including UI-only)
5064 thread.read_with(cx, |thread, _| {
5065 assert_eq!(
5066 thread.messages().count(),
5067 3,
5068 "Thread should have 3 messages"
5069 );
5070 assert_eq!(
5071 thread.messages().filter(|m| m.ui_only).count(),
5072 1,
5073 "Thread should have 1 UI-only message"
5074 );
5075 });
5076
5077 // Verify that UI-only messages are not serialized
5078 let serialized = thread
5079 .update(cx, |thread, cx| thread.serialize(cx))
5080 .await
5081 .unwrap();
5082 assert_eq!(
5083 serialized.messages.len(),
5084 2,
5085 "Serialized thread should only have 2 messages (no UI-only)"
5086 );
5087 }
5088
5089 #[gpui::test]
5090 async fn test_no_retry_without_burn_mode(cx: &mut TestAppContext) {
5091 let fs = init_test_settings(cx);
5092
5093 let project = create_test_project(&fs, cx, json!({})).await;
5094 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5095
5096 // Ensure we're in Normal mode (not Burn mode)
5097 thread.update(cx, |thread, _| {
5098 thread.set_completion_mode(CompletionMode::Normal);
5099 });
5100
5101 // Track error events
5102 let error_events = Arc::new(Mutex::new(Vec::new()));
5103 let error_events_clone = error_events.clone();
5104
5105 let _subscription = thread.update(cx, |_, cx| {
5106 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
5107 if let ThreadEvent::ShowError(error) = event {
5108 error_events_clone.lock().push(error.clone());
5109 }
5110 })
5111 });
5112
5113 // Create model that returns overloaded error
5114 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5115
5116 // Insert a user message
5117 thread.update(cx, |thread, cx| {
5118 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5119 });
5120
5121 // Start completion
5122 thread.update(cx, |thread, cx| {
5123 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5124 });
5125
5126 cx.run_until_parked();
5127
5128 // Verify no retry state was created
5129 thread.read_with(cx, |thread, _| {
5130 assert!(
5131 thread.retry_state.is_none(),
5132 "Should not have retry state in Normal mode"
5133 );
5134 });
5135
5136 // Check that a retryable error was reported
5137 let errors = error_events.lock();
5138 assert!(!errors.is_empty(), "Should have received an error event");
5139
5140 if let ThreadError::RetryableError {
5141 message: _,
5142 can_enable_burn_mode,
5143 } = &errors[0]
5144 {
5145 assert!(
5146 *can_enable_burn_mode,
5147 "Error should indicate burn mode can be enabled"
5148 );
5149 } else {
5150 panic!("Expected RetryableError, got {:?}", errors[0]);
5151 }
5152
5153 // Verify the thread is no longer generating
5154 thread.read_with(cx, |thread, _| {
5155 assert!(
5156 !thread.is_generating(),
5157 "Should not be generating after error without retry"
5158 );
5159 });
5160 }
5161
5162 #[gpui::test]
5163 async fn test_retry_canceled_on_stop(cx: &mut TestAppContext) {
5164 let fs = init_test_settings(cx);
5165
5166 let project = create_test_project(&fs, cx, json!({})).await;
5167 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5168
5169 // Enable Burn Mode to allow retries
5170 thread.update(cx, |thread, _| {
5171 thread.set_completion_mode(CompletionMode::Burn);
5172 });
5173
5174 // Create model that returns overloaded error
5175 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5176
5177 // Insert a user message
5178 thread.update(cx, |thread, cx| {
5179 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5180 });
5181
5182 // Start completion
5183 thread.update(cx, |thread, cx| {
5184 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5185 });
5186
5187 cx.run_until_parked();
5188
5189 // Verify retry was scheduled by checking for retry message
5190 let has_retry_message = thread.read_with(cx, |thread, _| {
5191 thread.messages.iter().any(|m| {
5192 m.ui_only
5193 && m.segments.iter().any(|s| {
5194 if let MessageSegment::Text(text) = s {
5195 text.contains("Retrying") && text.contains("seconds")
5196 } else {
5197 false
5198 }
5199 })
5200 })
5201 });
5202 assert!(has_retry_message, "Should have scheduled a retry");
5203
5204 // Cancel the completion before the retry happens
5205 thread.update(cx, |thread, cx| {
5206 thread.cancel_last_completion(None, cx);
5207 });
5208
5209 cx.run_until_parked();
5210
5211 // The retry should not have happened - no pending completions
5212 let fake_model = model.as_fake();
5213 assert_eq!(
5214 fake_model.pending_completions().len(),
5215 0,
5216 "Should have no pending completions after cancellation"
5217 );
5218
5219 // Verify the retry was canceled by checking retry state
5220 thread.read_with(cx, |thread, _| {
5221 if let Some(retry_state) = &thread.retry_state {
5222 panic!(
5223 "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5224 retry_state.attempt, retry_state.max_attempts, retry_state.intent
5225 );
5226 }
5227 });
5228 }
5229
    /// Shared helper: sends a summarization request to `model` and verifies
    /// the summary state transitions from `Generating` to `Error` when the
    /// summary stream ends without producing usable summary text.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(fake_model, cx);

        // While the summary request is in flight the summary reports
        // Generating, and the user-visible text falls back to the default.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
5264
    /// Drives the fake model through one successful turn: lets the pending
    /// completion start, streams a single text chunk, and closes the stream.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.send_last_completion_stream_text_chunk("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
5271
    /// Registers all global settings/state the agent tests depend on and
    /// returns the fake filesystem backing them. Initialization order follows
    /// the app's real boot sequence; keep it unchanged.
    fn init_test_settings(cx: &mut TestAppContext) -> Arc<dyn Fs> {
        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(fs.clone(), cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
            assistant_tool::init(cx);

            // Tools get a stub HTTP client that answers every request with 200.
            let http_client = Arc::new(http_client::HttpClientWithUrl::new(
                http_client::FakeHttpClient::with_200_response(),
                "http://localhost".to_string(),
                None,
            ));
            assistant_tools::init(http_client, cx);
        });
        fs
    }
5297
    /// Creates a test project rooted at `/test`, populated with `files`
    /// (a `serde_json` tree understood by `FakeFs::insert_tree`).
    async fn create_test_project(
        fs: &Arc<dyn Fs>,
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        fs.as_fake().insert_tree(path!("/test"), files).await;
        Project::test(fs.clone(), [path!("/test").as_ref()], cx).await
    }
5307
5308 async fn setup_test_environment(
5309 cx: &mut TestAppContext,
5310 project: Entity<Project>,
5311 ) -> (
5312 Entity<Workspace>,
5313 Entity<ThreadStore>,
5314 Entity<Thread>,
5315 Entity<ContextStore>,
5316 Arc<dyn LanguageModel>,
5317 ) {
5318 let (workspace, cx) =
5319 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5320
5321 let thread_store = cx
5322 .update(|_, cx| {
5323 ThreadStore::load(
5324 project.clone(),
5325 cx.new(|_| ToolWorkingSet::default()),
5326 None,
5327 Arc::new(PromptBuilder::new(None).unwrap()),
5328 cx,
5329 )
5330 })
5331 .await
5332 .unwrap();
5333
5334 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5335 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5336
5337 let provider = Arc::new(FakeLanguageModelProvider::default());
5338 let model = provider.test_model();
5339 let model: Arc<dyn LanguageModel> = Arc::new(model);
5340
5341 cx.update(|_, cx| {
5342 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5343 registry.set_default_model(
5344 Some(ConfiguredModel {
5345 provider: provider.clone(),
5346 model: model.clone(),
5347 }),
5348 cx,
5349 );
5350 registry.set_thread_summary_model(
5351 Some(ConfiguredModel {
5352 provider,
5353 model: model.clone(),
5354 }),
5355 cx,
5356 );
5357 })
5358 });
5359
5360 (workspace, thread_store, thread, context_store, model)
5361 }
5362
5363 async fn add_file_to_context(
5364 project: &Entity<Project>,
5365 context_store: &Entity<ContextStore>,
5366 path: &str,
5367 cx: &mut TestAppContext,
5368 ) -> Result<Entity<language::Buffer>> {
5369 let buffer_path = project
5370 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5371 .unwrap();
5372
5373 let buffer = project
5374 .update(cx, |project, cx| {
5375 project.open_buffer(buffer_path.clone(), cx)
5376 })
5377 .await
5378 .unwrap();
5379
5380 context_store.update(cx, |context_store, cx| {
5381 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5382 });
5383
5384 Ok(buffer)
5385 }
5386}