use crate::{
    agent_profile::AgentProfile,
    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
    thread_store::{
        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
        ThreadStore,
    },
    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use collections::{HashMap, HashSet};
use feature_flags::{self, FeatureFlagAppExt};
use futures::{FutureExt, StreamExt as _, future::Shared};
use git::repository::DiffType;
use gpui::{
    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
    WeakEntity, Window,
};
use language_model::{
    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
    TokenUsage,
};
use postage::stream::Stream as _;
use project::{
    Project,
    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
};
use prompt_store::{ModelContext, PromptBuilder};
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
    io::Write,
    ops::Range,
    sync::Arc,
    time::{Duration, Instant},
};
use thiserror::Error;
use util::{ResultExt as _, post_inc};
use uuid::Uuid;
use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};

const MAX_RETRY_ATTEMPTS: u8 = 3;
const BASE_RETRY_DELAY_SECS: u64 = 5;
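// A minimal sketch of the delay schedule these constants suggest, assuming an
// exponential backoff that doubles from `BASE_RETRY_DELAY_SECS` (the actual
// retry logic lives in `handle_retryable_error` further down; this formula is
// illustrative, not a restatement of it):
//
//     fn retry_delay(attempt: u8) -> Duration {
//         // attempt 1 => 5s, attempt 2 => 10s, attempt 3 => 20s
//         Duration::from_secs(BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32))
//     }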

#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
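// Usage sketch: IDs are freshly generated UUIDs and round-trip through their
// string form (illustrative only):
//
//     let id = ThreadId::new();
//     let parsed = ThreadId::from(id.to_string().as_str());
//     assert_eq!(id, parsed);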

/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    pub fn as_usize(&self) -> usize {
        self.0
    }
}
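// Note on `post_inc` semantics (a sketch; `util::post_inc` increments in place
// and returns the *previous* value, so the first message gets id 0):
//
//     let mut next_id = MessageId(0);
//     let first = next_id.post_inc();  // first == MessageId(0)
//     let second = next_id.post_inc(); // second == MessageId(1), next_id is now MessageId(2)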

/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    pub segments: Vec<MessageSegment>,
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
    pub is_hidden: bool,
    pub ui_only: bool,
}
impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    /// The model sometimes runs tools without producing any text, or produces only a marker ([`USING_TOOL_MARKER`]).
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }

    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        if let Some(MessageSegment::Thinking {
            text: segment,
            signature: current_signature,
        }) = self.segments.last_mut()
        {
            if let Some(signature) = signature {
                *current_signature = Some(signature);
            }
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Thinking {
                text: text.to_string(),
                signature,
            });
        }
    }

    pub fn push_redacted_thinking(&mut self, data: String) {
        self.segments.push(MessageSegment::RedactedThinking(data));
    }

    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

    pub fn to_string(&self) -> String {
        let mut result = String::new();

        if !self.loaded_context.text.is_empty() {
            result.push_str(&self.loaded_context.text);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    result.push_str("<think>\n");
                    result.push_str(text);
                    result.push_str("\n</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        result
    }
}
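// Illustrative output of `Message::to_string` for a message with a thinking
// segment followed by a text segment (a sketch of the format, not a test;
// note the text segment is appended directly after the closing tag):
//
//     <think>
//     Consider the user's request…
//     </think>The answer is 42.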

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(String),
}

impl MessageSegment {
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }

    pub fn text(&self) -> Option<&str> {
        match self {
            MessageSegment::Text(text) => Some(text),
            _ => None,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}

impl LastRestoreCheckpoint {
    pub fn message_id(&self) -> MessageId {
        match self {
            LastRestoreCheckpoint::Pending { message_id } => *message_id,
            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
        }
    }
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}

impl DetailedSummaryState {
    fn text(&self) -> Option<SharedString> {
        if let Self::Generated { text, .. } = self {
            Some(text.clone())
        } else {
            None
        }
    }
}

#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: u64,
    pub max: u64,
}

impl TotalTokenUsage {
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or("0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
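// Threshold behavior of `ratio` at the default 0.8 warning threshold
// (an illustrative sketch):
//
//     assert_eq!(TotalTokenUsage { total: 0, max: 0 }.ratio(), TokenUsageRatio::Normal); // unknown max
//     assert_eq!(TotalTokenUsage { total: 799, max: 1000 }.ratio(), TokenUsageRatio::Normal);
//     assert_eq!(TotalTokenUsage { total: 800, max: 1000 }.ratio(), TokenUsageRatio::Warning);
//     assert_eq!(TotalTokenUsage { total: 1000, max: 1000 }.ratio(), TokenUsageRatio::Exceeded);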

#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}

/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}

#[derive(Clone, Debug)]
struct RetryState {
    attempt: u8,
    max_attempts: u8,
    intent: CompletionIntent,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    Pending,
    Generating,
    Ready(SharedString),
    Error,
}

impl ThreadSummary {
    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");

    pub fn or_default(&self) -> SharedString {
        self.unwrap_or(Self::DEFAULT)
    }

    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
        self.ready().unwrap_or_else(|| message.into())
    }

    pub fn ready(&self) -> Option<SharedString> {
        match self {
            ThreadSummary::Ready(summary) => Some(summary.clone()),
            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
        }
    }
}
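// Sketch of the `ThreadSummary` accessors: only the `Ready` state carries a
// usable title; every other state falls back to the default (illustrative):
//
//     assert_eq!(ThreadSummary::Pending.or_default(), ThreadSummary::DEFAULT);
//     assert_eq!(
//         ThreadSummary::Ready("Fix flaky test".into()).or_default(),
//         SharedString::from("Fix flaky test")
//     );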

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}

impl Thread {
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }

    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }

    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }

    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        self.finalize_checkpoint(pending_checkpoint, cx);
    }

    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }

    pub fn is_turn_end(&self, ix: usize) -> bool {
        if self.messages.is_empty() {
            return false;
        }

        if !self.is_generating() && ix == self.messages.len() - 1 {
            return true;
        }

        let Some(message) = self.messages.get(ix) else {
            return false;
        };

        if message.role != Role::Assistant {
            return false;
        }

        self.messages
            .get(ix + 1)
            .and_then(|message| {
                self.message(message.id)
                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
            })
            .unwrap_or(false)
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display the image
                None
            }
        }
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Return tools that are both enabled and supported by the model
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
                .into_iter()
                .filter_map(|(name, tool)| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name,
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            Vec::default()
        }
    }

    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }

    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
            ui_only: false,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }

    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }

    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
    pub fn text(&self) -> String {
        let mut text = String::new();

        for message in &self.messages {
            text.push_str(match message.role {
                language_model::Role::User => "User:",
                language_model::Role::Assistant => "Agent:",
                language_model::Role::System => "System:",
            });
            text.push('\n');

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(content) => text.push_str(content),
                    MessageSegment::Thinking { text: content, .. } => {
                        text.push_str(&format!("<think>{}</think>", content))
                    }
                    MessageSegment::RedactedThinking(_) => {}
                }
            }
            text.push('\n');
        }

        text
    }

    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }

    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), intent, cx);

        self.stream_completion(request, model, intent, window, cx);
    }

    pub fn used_tools_since_last_user_message(&self) -> bool {
        for message in self.messages.iter().rev() {
            if self.tool_use.message_has_tool_results(message.id) {
                return true;
            } else if message.role == Role::User {
                return false;
            }
        }

        false
    }

    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }
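        // Note (illustrative): `message_ix_to_cache` ends up holding the index
        // of the last request message whose content is fully settled (no tool
        // results still pending), so exactly one conversation message, in
        // addition to the system prompt above, carries `cache: true`. That
        // places the Anthropic-style cache breakpoint after the longest stable
        // prefix of the conversation.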

        request.tools = available_tools;
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }

    fn to_summarize_request(
        &self,
        model: &Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        added_user_message: String,
        cx: &App,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(model, cx),
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => request_message
                        .content
                        .push(MessageContent::Text(text.clone())),
                    MessageSegment::Thinking { .. } => {}
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            if request_message.content.is_empty() {
                continue;
            }

            request.messages.push(request_message);
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(added_user_message)],
            cache: false,
        });

        request
    }

    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(error) => {
                                match error {
                                    LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
                                        anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
                                    }
                                    LanguageModelCompletionError::Overloaded => {
                                        anyhow::bail!(LanguageModelKnownError::Overloaded);
                                    }
                                    LanguageModelCompletionError::ApiInternalServerError => {
                                        anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
                                    }
                                    LanguageModelCompletionError::PromptTooLarge { tokens } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread.total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(model.max_token_count().saturating_add(1))
                                        });

                                        anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
                                    }
                                    LanguageModelCompletionError::ApiReadResponseError(io_error) => {
                                        anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
                                    }
                                    LanguageModelCompletionError::UnknownResponseFormat(error) => {
                                        anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
                                    }
                                    LanguageModelCompletionError::HttpResponseError { status, ref body } => {
                                        if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
                                            anyhow::bail!(known_error);
                                        } else {
                                            return Err(error.into());
                                        }
                                    }
                                    LanguageModelCompletionError::DeserializeResponse(error) => {
                                        anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
                                    }
                                    LanguageModelCompletionError::BadInputJson {
                                        id,
                                        tool_name,
                                        raw_input: invalid_input_json,
                                        json_parse_error,
                                    } => {
                                        thread.receive_invalid_tool_json(
                                            id,
                                            tool_name,
                                            invalid_input_json,
                                            json_parse_error,
                                            window,
                                            cx,
                                        );
                                        return Ok(());
                                    }
                                    // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
                                    err @ LanguageModelCompletionError::BadRequestFormat
                                    | err @ LanguageModelCompletionError::AuthenticationError
                                    | err @ LanguageModelCompletionError::PermissionError
                                    | err @ LanguageModelCompletionError::ApiEndpointNotFound
                                    | err @ LanguageModelCompletionError::SerializeRequest(_)
                                    | err @ LanguageModelCompletionError::BuildRequestBody(_)
                                    | err @ LanguageModelCompletionError::HttpSend(_) => {
                                        anyhow::bail!(err);
                                    }
                                    LanguageModelCompletionError::Other(error) => {
                                        return Err(error);
                                    }
                                }
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a streamed-chunk event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking { data } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some(tool_use.input.clone())
                                };
1720
1721 let ui_text = thread.tool_use.request_tool_use(
1722 last_assistant_message_id,
1723 tool_use,
1724 tool_use_metadata.clone(),
1725 cx,
1726 );
1727
1728 if let Some(input) = streamed_input {
1729 cx.emit(ThreadEvent::StreamedToolUse {
1730 tool_use_id,
1731 ui_text,
1732 input,
1733 });
1734 }
1735 }
1736 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1737 if let Some(completion) = thread
1738 .pending_completions
1739 .iter_mut()
1740 .find(|completion| completion.id == pending_completion_id)
1741 {
1742 match status_update {
1743 CompletionRequestStatus::Queued {
1744 position,
1745 } => {
1746 completion.queue_state = QueueState::Queued { position };
1747 }
1748 CompletionRequestStatus::Started => {
1749 completion.queue_state = QueueState::Started;
1750 }
1751 CompletionRequestStatus::Failed {
1752 code, message, request_id
1753 } => {
1754 anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
1755 }
1756 CompletionRequestStatus::UsageUpdated {
1757 amount, limit
1758 } => {
1759 thread.update_model_request_usage(amount as u32, limit, cx);
1760 }
1761 CompletionRequestStatus::ToolUseLimitReached => {
1762 thread.tool_use_limit_reached = true;
1763 cx.emit(ThreadEvent::ToolUseLimitReached);
1764 }
1765 }
1766 }
1767 }
1768 }
1769
1770 thread.touch_updated_at();
1771 cx.emit(ThreadEvent::StreamedCompletion);
1772 cx.notify();
1773
1774 thread.auto_capture_telemetry(cx);
1775 Ok(())
1776 })??;
1777
1778 smol::future::yield_now().await;
1779 }
1780
1781 thread.update(cx, |thread, cx| {
1782 thread.last_received_chunk_at = None;
1783 thread
1784 .pending_completions
1785 .retain(|completion| completion.id != pending_completion_id);
1786
1787 // If there is a response without tool use, summarize the message. Otherwise,
1788 // allow two tool uses before summarizing.
1789 if matches!(thread.summary, ThreadSummary::Pending)
1790 && thread.messages.len() >= 2
1791 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1792 {
1793 thread.summarize(cx);
1794 }
1795 })?;
1796
1797 anyhow::Ok(stop_reason)
1798 };
1799
1800 let result = stream_completion.await;
1801 let mut retry_scheduled = false;
1802
1803 thread
1804 .update(cx, |thread, cx| {
1805 thread.finalize_pending_checkpoint(cx);
1806 match result.as_ref() {
1807 Ok(stop_reason) => {
1808 match stop_reason {
1809 StopReason::ToolUse => {
1810 let tool_uses = thread.use_pending_tools(window, model.clone(), cx);
1811 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1812 }
1813 StopReason::EndTurn | StopReason::MaxTokens => {
1814 thread.project.update(cx, |project, cx| {
1815 project.set_agent_location(None, cx);
1816 });
1817 }
1818 StopReason::Refusal => {
1819 thread.project.update(cx, |project, cx| {
1820 project.set_agent_location(None, cx);
1821 });
1822
1823 // Remove the turn that was refused.
1824 //
1825 // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1826 {
1827 let mut messages_to_remove = Vec::new();
1828
1829 for (ix, message) in thread.messages.iter().enumerate().rev() {
1830 messages_to_remove.push(message.id);
1831
1832 if message.role == Role::User {
1833 if ix == 0 {
1834 break;
1835 }
1836
1837 if let Some(prev_message) = thread.messages.get(ix - 1) {
1838 if prev_message.role == Role::Assistant {
1839 break;
1840 }
1841 }
1842 }
1843 }
1844
1845 for message_id in messages_to_remove {
1846 thread.delete_message(message_id, cx);
1847 }
1848 }
1849
1850 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1851 header: "Language model refusal".into(),
1852 message: "Model refused to generate content for safety reasons.".into(),
1853 }));
1854 }
1855 }
1856
1857 // We successfully completed, so cancel any remaining retries.
1858 thread.retry_state = None;
1859 },
1860 Err(error) => {
1861 thread.project.update(cx, |project, cx| {
1862 project.set_agent_location(None, cx);
1863 });
1864
1865 fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
1866 let error_message = error
1867 .chain()
1868 .map(|err| err.to_string())
1869 .collect::<Vec<_>>()
1870 .join("\n");
1871 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1872 header: "Error interacting with language model".into(),
1873 message: SharedString::from(error_message.clone()),
1874 }));
1875 }
1876
1877 if error.is::<PaymentRequiredError>() {
1878 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1879 } else if let Some(error) =
1880 error.downcast_ref::<ModelRequestLimitReachedError>()
1881 {
1882 cx.emit(ThreadEvent::ShowError(
1883 ThreadError::ModelRequestLimitReached { plan: error.plan },
1884 ));
1885 } else if let Some(known_error) =
1886 error.downcast_ref::<LanguageModelKnownError>()
1887 {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                    LanguageModelKnownError::RateLimitExceeded { retry_after } => {
                                        let provider_name = model.provider_name();
                                        let error_message = format!(
                                            "{}'s API rate limit exceeded",
                                            provider_name.0.as_ref()
                                        );

                                        thread.handle_rate_limit_error(
                                            &error_message,
                                            *retry_after,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        retry_scheduled = true;
                                    }
                                    LanguageModelKnownError::Overloaded => {
                                        let provider_name = model.provider_name();
                                        let error_message = format!(
                                            "{}'s API servers are overloaded right now",
                                            provider_name.0.as_ref()
                                        );

                                        retry_scheduled = thread.handle_retryable_error(
                                            &error_message,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        if !retry_scheduled {
                                            emit_generic_error(error, cx);
                                        }
                                    }
                                    LanguageModelKnownError::ApiInternalServerError => {
                                        let provider_name = model.provider_name();
                                        let error_message = format!(
                                            "{}'s API server reported an internal server error",
                                            provider_name.0.as_ref()
                                        );

                                        retry_scheduled = thread.handle_retryable_error(
                                            &error_message,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        if !retry_scheduled {
                                            emit_generic_error(error, cx);
                                        }
                                    }
                                    LanguageModelKnownError::ReadResponseError(_)
                                    | LanguageModelKnownError::DeserializeResponse(_)
                                    | LanguageModelKnownError::UnknownResponseFormat(_) => {
                                        // In the future we will attempt to re-roll the response, but only once
                                        emit_generic_error(error, cx);
                                    }
                                }
                            } else {
                                emit_generic_error(error, cx);
                            }

                            if !retry_scheduled {
                                thread.cancel_last_completion(window, cx);
                            }
                        }
                    }

                    if !retry_scheduled {
                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
                    }

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }

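    /// Generates a short, single-line title for the thread using the configured
    /// thread-summary model. A usage sketch (assuming an `Entity<Thread>` handle
    /// named `thread`):
    ///
    /// ```ignore
    /// thread.update(cx, |thread, cx| thread.summarize(cx));
    /// ```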
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            log::warn!("No thread summary model");
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");

        let request = self.to_summarize_request(
            &model.model,
            CompletionIntent::ThreadSummarization,
            added_user_message.into(),
            cx,
        );

        self.summary = ThreadSummary::Generating;

        self.pending_summary = cx.spawn(async move |this, cx| {
            let result = async {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let Ok(event) = event else {
                        continue;
                    };
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            this.update(cx, |thread, cx| {
                                thread.update_model_request_usage(amount as u32, limit, cx);
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                anyhow::Ok(new_summary)
            }
            .await;

            this.update(cx, |this, cx| {
                match result {
                    Ok(new_summary) => {
                        if new_summary.is_empty() {
                            this.summary = ThreadSummary::Error;
                        } else {
                            this.summary = ThreadSummary::Ready(new_summary.into());
                        }
                    }
                    Err(err) => {
                        this.summary = ThreadSummary::Error;
                        log::error!("Failed to generate thread summary: {}", err);
                    }
                }
                cx.emit(ThreadEvent::SummaryGenerated);
            })
            .log_err()?;

            Some(())
        });
    }

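    /// Handles a provider rate limit by scheduling exactly one retry after the
    /// provider-specified cooldown, surfacing a UI-only countdown message.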
    fn handle_rate_limit_error(
        &mut self,
        error_message: &str,
        retry_after: Duration,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        // For rate limit errors, we only retry once with the specified duration
        let retry_message = format!(
            "{error_message}. Retrying in {} seconds…",
            retry_after.as_secs()
        );

        // Add a UI-only message instead of a regular message
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role: Role::System,
            segments: vec![MessageSegment::Text(retry_message)],
            loaded_context: LoadedContext::default(),
            creases: Vec::new(),
            is_hidden: false,
            ui_only: true,
        });
        cx.emit(ThreadEvent::MessageAdded(id));
        // Schedule the retry
        let thread_handle = cx.entity().downgrade();

        cx.spawn(async move |_thread, cx| {
            cx.background_executor().timer(retry_after).await;

            thread_handle
                .update(cx, |thread, cx| {
                    // Retry the completion
                    thread.send_to_model(model, intent, window, cx);
                })
                .log_err();
        })
        .detach();
    }

    fn handle_retryable_error(
        &mut self,
        error_message: &str,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx)
    }

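    /// Schedules a retry with exponential backoff unless `custom_delay` is given.
    /// With `BASE_RETRY_DELAY_SECS = 5` and `MAX_RETRY_ATTEMPTS = 3`, the backoff
    /// schedule is 5s, 10s, then 20s; once attempts are exhausted, the retry state
    /// is cleared and a `RetriesFailed` event is emitted. Returns whether a retry
    /// was scheduled.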
    fn handle_retryable_error_with_delay(
        &mut self,
        error_message: &str,
        custom_delay: Option<Duration>,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        let retry_state = self.retry_state.get_or_insert(RetryState {
            attempt: 0,
            max_attempts: MAX_RETRY_ATTEMPTS,
            intent,
        });

        retry_state.attempt += 1;
        let attempt = retry_state.attempt;
        let max_attempts = retry_state.max_attempts;
        let intent = retry_state.intent;

        if attempt <= max_attempts {
            // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
            let delay = if let Some(custom_delay) = custom_delay {
                custom_delay
            } else {
                let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
                Duration::from_secs(delay_secs)
            };

            // Add a transient message to inform the user
            let delay_secs = delay.as_secs();
            let retry_message = format!(
                "{}. Retrying (attempt {} of {}) in {} seconds...",
                error_message, attempt, max_attempts, delay_secs
            );

            // Add a UI-only message instead of a regular message
            let id = self.next_message_id.post_inc();
            self.messages.push(Message {
                id,
                role: Role::System,
                segments: vec![MessageSegment::Text(retry_message)],
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
                is_hidden: false,
                ui_only: true,
            });
            cx.emit(ThreadEvent::MessageAdded(id));

            // Schedule the retry
            let thread_handle = cx.entity().downgrade();

            cx.spawn(async move |_thread, cx| {
                cx.background_executor().timer(delay).await;

                thread_handle
                    .update(cx, |thread, cx| {
                        // Retry the completion
                        thread.send_to_model(model, intent, window, cx);
                    })
                    .log_err();
            })
            .detach();

            true
        } else {
            // Max retries exceeded
            self.retry_state = None;

            let notification_text = if max_attempts == 1 {
                "Failed after retrying.".into()
            } else {
                format!("Failed after retrying {} times.", max_attempts).into()
            };

            // Stop generating since we're giving up on retrying.
            self.pending_completions.clear();

            cx.emit(ThreadEvent::RetriesFailed {
                message: notification_text,
            });

            false
        }
    }

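    /// Kicks off detailed-summary generation, unless one is already generating or
    /// was already generated for the current last message.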
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }

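    /// Waits until detailed-summary generation settles, returning the generated
    /// text or falling back to the plain thread text if none was produced.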
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }

    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.detailed_summary_rx
            .borrow()
            .text()
            .unwrap_or_else(|| self.text().into())
    }

    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }

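    /// Starts every idle pending tool use against the given model, returning the
    /// set of tool uses that were dispatched.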
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        let request =
            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
        }

        pending_tool_uses
    }

    fn use_pending_tool(
        &mut self,
        tool_use: PendingToolUse,
        request: Arc<LanguageModelRequest>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
        };

        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
        }

        if tool.needs_confirmation(&tool_use.input, cx)
            && !AgentSettings::get_global(cx).always_allow_tool_actions
        {
            self.tool_use.confirm_tool_use(
                tool_use.id,
                tool_use.ui_text,
                tool_use.input,
                request,
                tool,
            );
            cx.emit(ThreadEvent::ToolConfirmationNeeded);
        } else {
            self.run_tool(
                tool_use.id,
                tool_use.ui_text,
                tool_use.input,
                request,
                tool,
                model,
                window,
                cx,
            );
        }
    }

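    /// Records an error tool result for a tool name the model hallucinated (or
    /// that is disabled in the current profile), echoing the list of available
    /// tools back so the model can self-correct on the next turn.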
    pub fn handle_hallucinated_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        hallucinated_tool_name: Arc<str>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let available_tools = self.profile.enabled_tools(cx);

        let tool_list = available_tools
            .iter()
            .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
            .collect::<Vec<_>>()
            .join("\n");

        let error_message = format!(
            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
            hallucinated_tool_name, tool_list
        );

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            hallucinated_tool_name,
            Err(anyhow!("Missing tool call: {error_message}")),
            self.configured_model.as_ref(),
        );

        cx.emit(ThreadEvent::MissingToolUse {
            tool_use_id: tool_use_id.clone(),
            ui_text: error_message.into(),
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }

    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
        );
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }

    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }

    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }

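    /// Called whenever a tool completes. Once every outstanding tool has finished,
    /// automatically sends the accumulated tool results back to the model, unless
    /// the run was canceled.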
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
                if !canceled {
                    self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
                }
                self.auto_capture_telemetry(cx);
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }

    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();

        self.retry_state = None;

        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }

    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }

    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }

    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }

    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }

    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }

    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }

    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }

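    /// Renders the entire thread (attached context, messages, tool uses, and tool
    /// results) as a Markdown transcript.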
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        let summary = self.summary().or_default();
        writeln!(markdown, "# {summary}\n")?;

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Agent",
                    Role::System => "System",
                }
            )?;

            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                match &tool_result.content {
                    LanguageModelToolResultContent::Text(text) => {
                        writeln!(markdown, "{text}")?;
                    }
                    LanguageModelToolResultContent::Image(image) => {
                        // Embed the image as an inline Markdown data URI.
                        writeln!(markdown, "![Image](data:image/png;base64,{})", image.source)?;
                    }
                }

                if let Some(output) = tool_result.output.as_ref() {
                    writeln!(
                        markdown,
                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
                        serde_json::to_string_pretty(output)?
                    )?;
                }
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }

    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }

    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
            return;
        }

        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client();
        let serialize_task = self.serialize(cx);

        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }

    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }

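    /// Returns the token usage recorded as of the message just before `message_id`:
    /// usage for message `i` is stored at `request_token_usage[i - 1]`, so the first
    /// message always reports zero usage.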
    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
        let Some(model) = self.configured_model.as_ref() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        let index = self
            .messages
            .iter()
            .position(|msg| msg.id == message_id)
            .unwrap_or(0);

        if index == 0 {
            return TotalTokenUsage { total: 0, max };
        }

        let token_usage = &self
            .request_token_usage
            .get(index - 1)
            .cloned()
            .unwrap_or_default();

        TotalTokenUsage {
            total: token_usage.total_tokens(),
            max,
        }
    }

    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
        let model = self.configured_model.as_ref()?;

        let max = model.model.max_token_count();

        if let Some(exceeded_error) = &self.exceeded_window_error {
            if model.model.id() == exceeded_error.model_id {
                return Some(TotalTokenUsage {
                    total: exceeded_error.token_count,
                    max,
                });
            }
        }

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens();

        Some(TotalTokenUsage { total, max })
    }

    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }

    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }

    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            project.user_store().update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            })
        });
    }

    pub fn deny_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let err = Err(anyhow::anyhow!(
            "Permission to run tool action denied by user"
        ));

        self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            err,
            self.configured_model.as_ref(),
        );
        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
    }
}

#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}

#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
    RetriesFailed {
        message: SharedString,
    },
}

impl EventEmitter<ThreadEvent> for Thread {}

struct PendingCompletion {
    id: usize,
    queue_state: QueueState,
    _task: Task<()>,
}

/// Resolves tool name conflicts by ensuring all tool names are unique.
///
/// When multiple tools have the same name, this function applies the following rules:
/// 1. Native tools always keep their original name
/// 2. Context server tools get prefixed with their server ID and an underscore
/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
///
/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
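///
/// For example, two context-server tools both named `tool3`, provided by servers
/// `mcp-1` and `mcp-2`, are exposed as `mcp-1_tool3` and `mcp-2_tool3`.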
3189fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
3190 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
3191 let mut tool_name = tool.name();
3192 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
3193 tool_name
3194 }
3195
3196 const MAX_TOOL_NAME_LENGTH: usize = 64;
3197
3198 let mut duplicated_tool_names = HashSet::default();
3199 let mut seen_tool_names = HashSet::default();
3200 for tool in tools {
3201 let tool_name = resolve_tool_name(tool);
3202 if seen_tool_names.contains(&tool_name) {
3203 debug_assert!(
3204 tool.source() != assistant_tool::ToolSource::Native,
3205 "There are two built-in tools with the same name: {}",
3206 tool_name
3207 );
3208 duplicated_tool_names.insert(tool_name);
3209 } else {
3210 seen_tool_names.insert(tool_name);
3211 }
3212 }
3213
3214 if duplicated_tool_names.is_empty() {
3215 return tools
3216 .into_iter()
3217 .map(|tool| (resolve_tool_name(tool), tool.clone()))
3218 .collect();
3219 }
3220
3221 tools
3222 .into_iter()
3223 .filter_map(|tool| {
3224 let mut tool_name = resolve_tool_name(tool);
3225 if !duplicated_tool_names.contains(&tool_name) {
3226 return Some((tool_name, tool.clone()));
3227 }
3228 match tool.source() {
3229 assistant_tool::ToolSource::Native => {
3230 // Built-in tools always keep their original name
3231 Some((tool_name, tool.clone()))
3232 }
3233 assistant_tool::ToolSource::ContextServer { id } => {
3234 // Context server tools are prefixed with the context server ID, and truncated if necessary
3235 tool_name.insert(0, '_');
3236 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
3237 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
3238 let mut id = id.to_string();
3239 id.truncate(len);
3240 tool_name.insert_str(0, &id);
3241 } else {
3242 tool_name.insert_str(0, &id);
3243 }
3244
3245 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
3246
3247 if seen_tool_names.contains(&tool_name) {
3248 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
3249 None
3250 } else {
3251 Some((tool_name, tool.clone()))
3252 }
3253 }
3254 }
3255 })
3256 .collect()
3257}
3258
3259#[cfg(test)]
3260mod tests {
3261 use super::*;
3262 use crate::{
3263 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3264 };
3265
3266 // Test-specific constants
3267 const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3268 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3269 use assistant_tool::ToolRegistry;
3270 use futures::StreamExt;
3271 use futures::future::BoxFuture;
3272 use futures::stream::BoxStream;
3273 use gpui::TestAppContext;
3274 use icons::IconName;
3275 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3276 use language_model::{
3277 LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3278 LanguageModelProviderName, LanguageModelToolChoice,
3279 };
3280 use parking_lot::Mutex;
3281 use project::{FakeFs, Project};
3282 use prompt_store::PromptBuilder;
3283 use serde_json::json;
3284 use settings::{Settings, SettingsStore};
3285 use std::sync::Arc;
3286 use std::time::Duration;
3287 use theme::ThemeSettings;
3288 use util::path;
3289 use workspace::Workspace;
3290
3291 #[gpui::test]
3292 async fn test_message_with_context(cx: &mut TestAppContext) {
3293 init_test_settings(cx);
3294
3295 let project = create_test_project(
3296 cx,
3297 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3298 )
3299 .await;
3300
3301 let (_workspace, _thread_store, thread, context_store, model) =
3302 setup_test_environment(cx, project.clone()).await;
3303
3304 add_file_to_context(&project, &context_store, "test/code.rs", cx)
3305 .await
3306 .unwrap();
3307
3308 let context =
3309 context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
3310 let loaded_context = cx
3311 .update(|cx| load_context(vec![context], &project, &None, cx))
3312 .await;
3313
3314 // Insert user message with context
3315 let message_id = thread.update(cx, |thread, cx| {
3316 thread.insert_user_message(
3317 "Please explain this code",
3318 loaded_context,
3319 None,
3320 Vec::new(),
3321 cx,
3322 )
3323 });
3324
3325 // Check content and context in message object
3326 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3327
3328 // Use different path format strings based on platform for the test
3329 #[cfg(windows)]
3330 let path_part = r"test\code.rs";
3331 #[cfg(not(windows))]
3332 let path_part = "test/code.rs";
3333
3334 let expected_context = format!(
3335 r#"
3336<context>
3337The following items were attached by the user. They are up-to-date and don't need to be re-read.
3338
3339<files>
3340```rs {path_part}
3341fn main() {{
3342 println!("Hello, world!");
3343}}
3344```
3345</files>
3346</context>
3347"#
3348 );
3349
3350 assert_eq!(message.role, Role::User);
3351 assert_eq!(message.segments.len(), 1);
3352 assert_eq!(
3353 message.segments[0],
3354 MessageSegment::Text("Please explain this code".to_string())
3355 );
3356 assert_eq!(message.loaded_context.text, expected_context);
3357
3358 // Check message in request
3359 let request = thread.update(cx, |thread, cx| {
3360 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3361 });
3362
3363 assert_eq!(request.messages.len(), 2);
3364 let expected_full_message = format!("{}Please explain this code", expected_context);
3365 assert_eq!(request.messages[1].string_contents(), expected_full_message);
3366 }
3367
3368 #[gpui::test]
3369 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
3370 init_test_settings(cx);
3371
3372 let project = create_test_project(
3373 cx,
3374 json!({
3375 "file1.rs": "fn function1() {}\n",
3376 "file2.rs": "fn function2() {}\n",
3377 "file3.rs": "fn function3() {}\n",
3378 "file4.rs": "fn function4() {}\n",
3379 }),
3380 )
3381 .await;
3382
3383 let (_, _thread_store, thread, context_store, model) =
3384 setup_test_environment(cx, project.clone()).await;
3385
3386 // First message with context 1
3387 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
3388 .await
3389 .unwrap();
3390 let new_contexts = context_store.update(cx, |store, cx| {
3391 store.new_context_for_thread(thread.read(cx), None)
3392 });
3393 assert_eq!(new_contexts.len(), 1);
3394 let loaded_context = cx
3395 .update(|cx| load_context(new_contexts, &project, &None, cx))
3396 .await;
3397 let message1_id = thread.update(cx, |thread, cx| {
3398 thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
3399 });
3400
3401 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
3402 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
3403 .await
3404 .unwrap();
3405 let new_contexts = context_store.update(cx, |store, cx| {
3406 store.new_context_for_thread(thread.read(cx), None)
3407 });
3408 assert_eq!(new_contexts.len(), 1);
3409 let loaded_context = cx
3410 .update(|cx| load_context(new_contexts, &project, &None, cx))
3411 .await;
3412 let message2_id = thread.update(cx, |thread, cx| {
3413 thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
3414 });
3415
3416 // Third message with all three contexts (contexts 1 and 2 should be skipped)
3417 //
3418 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
3419 .await
3420 .unwrap();
3421 let new_contexts = context_store.update(cx, |store, cx| {
3422 store.new_context_for_thread(thread.read(cx), None)
3423 });
3424 assert_eq!(new_contexts.len(), 1);
3425 let loaded_context = cx
3426 .update(|cx| load_context(new_contexts, &project, &None, cx))
3427 .await;
3428 let message3_id = thread.update(cx, |thread, cx| {
3429 thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
3430 });
3431
3432 // Check what contexts are included in each message
3433 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
3434 (
3435 thread.message(message1_id).unwrap().clone(),
3436 thread.message(message2_id).unwrap().clone(),
3437 thread.message(message3_id).unwrap().clone(),
3438 )
3439 });
3440
3441 // First message should include context 1
3442 assert!(message1.loaded_context.text.contains("file1.rs"));
3443
3444 // Second message should include only context 2 (not 1)
3445 assert!(!message2.loaded_context.text.contains("file1.rs"));
3446 assert!(message2.loaded_context.text.contains("file2.rs"));
3447
3448 // Third message should include only context 3 (not 1 or 2)
3449 assert!(!message3.loaded_context.text.contains("file1.rs"));
3450 assert!(!message3.loaded_context.text.contains("file2.rs"));
3451 assert!(message3.loaded_context.text.contains("file3.rs"));
3452
3453 // Check entire request to make sure all contexts are properly included
3454 let request = thread.update(cx, |thread, cx| {
3455 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3456 });
3457
3458 // The request should contain all 3 messages
3459 assert_eq!(request.messages.len(), 4);
3460
3461 // Check that the contexts are properly formatted in each message
3462 assert!(request.messages[1].string_contents().contains("file1.rs"));
3463 assert!(!request.messages[1].string_contents().contains("file2.rs"));
3464 assert!(!request.messages[1].string_contents().contains("file3.rs"));
3465
3466 assert!(!request.messages[2].string_contents().contains("file1.rs"));
3467 assert!(request.messages[2].string_contents().contains("file2.rs"));
3468 assert!(!request.messages[2].string_contents().contains("file3.rs"));
3469
3470 assert!(!request.messages[3].string_contents().contains("file1.rs"));
3471 assert!(!request.messages[3].string_contents().contains("file2.rs"));
3472 assert!(request.messages[3].string_contents().contains("file3.rs"));
3473
3474 add_file_to_context(&project, &context_store, "test/file4.rs", cx)
3475 .await
3476 .unwrap();
3477 let new_contexts = context_store.update(cx, |store, cx| {
3478 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3479 });
3480 assert_eq!(new_contexts.len(), 3);
3481 let loaded_context = cx
3482 .update(|cx| load_context(new_contexts, &project, &None, cx))
3483 .await
3484 .loaded_context;
3485
3486 assert!(!loaded_context.text.contains("file1.rs"));
3487 assert!(loaded_context.text.contains("file2.rs"));
3488 assert!(loaded_context.text.contains("file3.rs"));
3489 assert!(loaded_context.text.contains("file4.rs"));
3490
3491 let new_contexts = context_store.update(cx, |store, cx| {
3492 // Remove file4.rs
3493 store.remove_context(&loaded_context.contexts[2].handle(), cx);
3494 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3495 });
3496 assert_eq!(new_contexts.len(), 2);
3497 let loaded_context = cx
3498 .update(|cx| load_context(new_contexts, &project, &None, cx))
3499 .await
3500 .loaded_context;
3501
3502 assert!(!loaded_context.text.contains("file1.rs"));
3503 assert!(loaded_context.text.contains("file2.rs"));
3504 assert!(loaded_context.text.contains("file3.rs"));
3505 assert!(!loaded_context.text.contains("file4.rs"));
3506
3507 let new_contexts = context_store.update(cx, |store, cx| {
3508 // Remove file3.rs
3509 store.remove_context(&loaded_context.contexts[1].handle(), cx);
3510 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3511 });
3512 assert_eq!(new_contexts.len(), 1);
3513 let loaded_context = cx
3514 .update(|cx| load_context(new_contexts, &project, &None, cx))
3515 .await
3516 .loaded_context;
3517
3518 assert!(!loaded_context.text.contains("file1.rs"));
3519 assert!(loaded_context.text.contains("file2.rs"));
3520 assert!(!loaded_context.text.contains("file3.rs"));
3521 assert!(!loaded_context.text.contains("file4.rs"));
3522 }
3523
3524 #[gpui::test]
3525 async fn test_message_without_files(cx: &mut TestAppContext) {
3526 init_test_settings(cx);
3527
3528 let project = create_test_project(
3529 cx,
3530 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3531 )
3532 .await;
3533
3534 let (_, _thread_store, thread, _context_store, model) =
3535 setup_test_environment(cx, project.clone()).await;
3536
3537 // Insert user message without any context (empty context vector)
3538 let message_id = thread.update(cx, |thread, cx| {
3539 thread.insert_user_message(
3540 "What is the best way to learn Rust?",
3541 ContextLoadResult::default(),
3542 None,
3543 Vec::new(),
3544 cx,
3545 )
3546 });
3547
3548 // Check content and context in message object
3549 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3550
3551 // Context should be empty when no files are included
3552 assert_eq!(message.role, Role::User);
3553 assert_eq!(message.segments.len(), 1);
3554 assert_eq!(
3555 message.segments[0],
3556 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3557 );
3558 assert_eq!(message.loaded_context.text, "");
3559
3560 // Check message in request
3561 let request = thread.update(cx, |thread, cx| {
3562 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3563 });
3564
3565 assert_eq!(request.messages.len(), 2);
3566 assert_eq!(
3567 request.messages[1].string_contents(),
3568 "What is the best way to learn Rust?"
3569 );
3570
3571 // Add second message, also without context
3572 let message2_id = thread.update(cx, |thread, cx| {
3573 thread.insert_user_message(
3574 "Are there any good books?",
3575 ContextLoadResult::default(),
3576 None,
3577 Vec::new(),
3578 cx,
3579 )
3580 });
3581
3582 let message2 =
3583 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3584 assert_eq!(message2.loaded_context.text, "");
3585
3586 // Check that both messages appear in the request
3587 let request = thread.update(cx, |thread, cx| {
3588 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3589 });
3590
3591 assert_eq!(request.messages.len(), 3);
3592 assert_eq!(
3593 request.messages[1].string_contents(),
3594 "What is the best way to learn Rust?"
3595 );
3596 assert_eq!(
3597 request.messages[2].string_contents(),
3598 "Are there any good books?"
3599 );
3600 }
3601
3602 #[gpui::test]
3603 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3604 init_test_settings(cx);
3605
3606 let project = create_test_project(
3607 cx,
3608 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3609 )
3610 .await;
3611
3612 let (_workspace, thread_store, thread, _context_store, _model) =
3613 setup_test_environment(cx, project.clone()).await;
3614
3615 // Check that we are starting with the default profile
3616 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3617 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3618 assert_eq!(
3619 profile,
3620 AgentProfile::new(AgentProfileId::default(), tool_set)
3621 );
3622 }
3623
3624 #[gpui::test]
3625 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3626 init_test_settings(cx);
3627
3628 let project = create_test_project(
3629 cx,
3630 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3631 )
3632 .await;
3633
3634 let (_workspace, thread_store, thread, _context_store, _model) =
3635 setup_test_environment(cx, project.clone()).await;
3636
3637 // Profile gets serialized with default values
3638 let serialized = thread
3639 .update(cx, |thread, cx| thread.serialize(cx))
3640 .await
3641 .unwrap();
3642
3643 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3644
3645 let deserialized = cx.update(|cx| {
3646 thread.update(cx, |thread, cx| {
3647 Thread::deserialize(
3648 thread.id.clone(),
3649 serialized,
3650 thread.project.clone(),
3651 thread.tools.clone(),
3652 thread.prompt_builder.clone(),
3653 thread.project_context.clone(),
3654 None,
3655 cx,
3656 )
3657 })
3658 });
3659 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3660
3661 assert_eq!(
3662 deserialized.profile,
3663 AgentProfile::new(AgentProfileId::default(), tool_set)
3664 );
3665 }
3666
3667 #[gpui::test]
3668 async fn test_temperature_setting(cx: &mut TestAppContext) {
3669 init_test_settings(cx);
3670
3671 let project = create_test_project(
3672 cx,
3673 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3674 )
3675 .await;
3676
3677 let (_workspace, _thread_store, thread, _context_store, model) =
3678 setup_test_environment(cx, project.clone()).await;
3679
3680 // Both model and provider
3681 cx.update(|cx| {
3682 AgentSettings::override_global(
3683 AgentSettings {
3684 model_parameters: vec![LanguageModelParameters {
3685 provider: Some(model.provider_id().0.to_string().into()),
3686 model: Some(model.id().0.clone()),
3687 temperature: Some(0.66),
3688 }],
3689 ..AgentSettings::get_global(cx).clone()
3690 },
3691 cx,
3692 );
3693 });
3694
3695 let request = thread.update(cx, |thread, cx| {
3696 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3697 });
3698 assert_eq!(request.temperature, Some(0.66));
3699
3700 // Only model
3701 cx.update(|cx| {
3702 AgentSettings::override_global(
3703 AgentSettings {
3704 model_parameters: vec![LanguageModelParameters {
3705 provider: None,
3706 model: Some(model.id().0.clone()),
3707 temperature: Some(0.66),
3708 }],
3709 ..AgentSettings::get_global(cx).clone()
3710 },
3711 cx,
3712 );
3713 });
3714
3715 let request = thread.update(cx, |thread, cx| {
3716 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3717 });
3718 assert_eq!(request.temperature, Some(0.66));
3719
3720 // Only provider
3721 cx.update(|cx| {
3722 AgentSettings::override_global(
3723 AgentSettings {
3724 model_parameters: vec![LanguageModelParameters {
3725 provider: Some(model.provider_id().0.to_string().into()),
3726 model: None,
3727 temperature: Some(0.66),
3728 }],
3729 ..AgentSettings::get_global(cx).clone()
3730 },
3731 cx,
3732 );
3733 });
3734
3735 let request = thread.update(cx, |thread, cx| {
3736 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3737 });
3738 assert_eq!(request.temperature, Some(0.66));
3739
3740 // Same model name, different provider
3741 cx.update(|cx| {
3742 AgentSettings::override_global(
3743 AgentSettings {
3744 model_parameters: vec![LanguageModelParameters {
3745 provider: Some("anthropic".into()),
3746 model: Some(model.id().0.clone()),
3747 temperature: Some(0.66),
3748 }],
3749 ..AgentSettings::get_global(cx).clone()
3750 },
3751 cx,
3752 );
3753 });
3754
3755 let request = thread.update(cx, |thread, cx| {
3756 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3757 });
3758 assert_eq!(request.temperature, None);
3759 }
3760
3761 #[gpui::test]
3762 async fn test_thread_summary(cx: &mut TestAppContext) {
3763 init_test_settings(cx);
3764
3765 let project = create_test_project(cx, json!({})).await;
3766
3767 let (_, _thread_store, thread, _context_store, model) =
3768 setup_test_environment(cx, project.clone()).await;
3769
3770 // Initial state should be pending
3771 thread.read_with(cx, |thread, _| {
3772 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3773 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3774 });
3775
3776 // Manually setting the summary should not be allowed in this state
3777 thread.update(cx, |thread, cx| {
3778 thread.set_summary("This should not work", cx);
3779 });
3780
3781 thread.read_with(cx, |thread, _| {
3782 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3783 });
3784
3785 // Send a message
3786 thread.update(cx, |thread, cx| {
3787 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3788 thread.send_to_model(
3789 model.clone(),
3790 CompletionIntent::ThreadSummarization,
3791 None,
3792 cx,
3793 );
3794 });
3795
3796 let fake_model = model.as_fake();
3797 simulate_successful_response(&fake_model, cx);
3798
3799 // Should start generating summary when there are >= 2 messages
3800 thread.read_with(cx, |thread, _| {
3801 assert_eq!(*thread.summary(), ThreadSummary::Generating);
3802 });
3803
3804 // Should not be able to set the summary while generating
3805 thread.update(cx, |thread, cx| {
3806 thread.set_summary("This should not work either", cx);
3807 });
3808
3809 thread.read_with(cx, |thread, _| {
3810 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3811 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3812 });
3813
3814 cx.run_until_parked();
3815 fake_model.stream_last_completion_response("Brief");
3816 fake_model.stream_last_completion_response(" Introduction");
3817 fake_model.end_last_completion_stream();
3818 cx.run_until_parked();
3819
3820 // Summary should be set
3821 thread.read_with(cx, |thread, _| {
3822 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3823 assert_eq!(thread.summary().or_default(), "Brief Introduction");
3824 });
3825
3826 // Now we should be able to set a summary
3827 thread.update(cx, |thread, cx| {
3828 thread.set_summary("Brief Intro", cx);
3829 });
3830
3831 thread.read_with(cx, |thread, _| {
3832 assert_eq!(thread.summary().or_default(), "Brief Intro");
3833 });
3834
3835 // Test setting an empty summary (should default to DEFAULT)
3836 thread.update(cx, |thread, cx| {
3837 thread.set_summary("", cx);
3838 });
3839
3840 thread.read_with(cx, |thread, _| {
3841 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3842 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3843 });
3844 }
3845
3846 #[gpui::test]
3847 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3848 init_test_settings(cx);
3849
3850 let project = create_test_project(cx, json!({})).await;
3851
3852 let (_, _thread_store, thread, _context_store, model) =
3853 setup_test_environment(cx, project.clone()).await;
3854
3855 test_summarize_error(&model, &thread, cx);
3856
3857 // Now we should be able to set a summary
3858 thread.update(cx, |thread, cx| {
3859 thread.set_summary("Brief Intro", cx);
3860 });
3861
3862 thread.read_with(cx, |thread, _| {
3863 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3864 assert_eq!(thread.summary().or_default(), "Brief Intro");
3865 });
3866 }
3867
3868 #[gpui::test]
3869 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3870 init_test_settings(cx);
3871
3872 let project = create_test_project(cx, json!({})).await;
3873
3874 let (_, _thread_store, thread, _context_store, model) =
3875 setup_test_environment(cx, project.clone()).await;
3876
3877 test_summarize_error(&model, &thread, cx);
3878
3879 // Sending another message should not trigger another summarize request
3880 thread.update(cx, |thread, cx| {
3881 thread.insert_user_message(
3882 "How are you?",
3883 ContextLoadResult::default(),
3884 None,
3885 vec![],
3886 cx,
3887 );
3888 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3889 });
3890
3891 let fake_model = model.as_fake();
3892 simulate_successful_response(&fake_model, cx);
3893
3894 thread.read_with(cx, |thread, _| {
3895 // State is still Error, not Generating
3896 assert!(matches!(thread.summary(), ThreadSummary::Error));
3897 });
3898
3899 // But the summarize request can be invoked manually
3900 thread.update(cx, |thread, cx| {
3901 thread.summarize(cx);
3902 });
3903
3904 thread.read_with(cx, |thread, _| {
3905 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3906 });
3907
3908 cx.run_until_parked();
3909 fake_model.stream_last_completion_response("A successful summary");
3910 fake_model.end_last_completion_stream();
3911 cx.run_until_parked();
3912
3913 thread.read_with(cx, |thread, _| {
3914 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3915 assert_eq!(thread.summary().or_default(), "A successful summary");
3916 });
3917 }
3918
3919 #[gpui::test]
3920 fn test_resolve_tool_name_conflicts() {
3921 use assistant_tool::{Tool, ToolSource};
3922
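        // Unique names, including a single MCP tool, pass through unchanged.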
3923 assert_resolve_tool_name_conflicts(
3924 vec![
3925 TestTool::new("tool1", ToolSource::Native),
3926 TestTool::new("tool2", ToolSource::Native),
3927 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3928 ],
3929 vec!["tool1", "tool2", "tool3"],
3930 );
3931
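        // Two MCP tools share a name, so each is prefixed with its server ID.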
3932 assert_resolve_tool_name_conflicts(
3933 vec![
3934 TestTool::new("tool1", ToolSource::Native),
3935 TestTool::new("tool2", ToolSource::Native),
3936 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3937 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
3938 ],
3939 vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
3940 );
3941
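        // A native tool keeps the contested name; the MCP duplicates get prefixed.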
3942 assert_resolve_tool_name_conflicts(
3943 vec![
3944 TestTool::new("tool1", ToolSource::Native),
3945 TestTool::new("tool2", ToolSource::Native),
3946 TestTool::new("tool3", ToolSource::Native),
3947 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3948 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
3949 ],
3950 vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
3951 );
3952
        // Test that a tool with a very long name is always truncated
3954 assert_resolve_tool_name_conflicts(
3955 vec![TestTool::new(
                "tool-with-more-than-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
3957 ToolSource::Native,
3958 )],
            vec!["tool-with-more-than-64-characters-blah-blah-blah-blah-blah-blah-"],
3960 );
3961
        // Test deduplication of tools with very long names; in this case, the MCP server name should be truncated
3963 assert_resolve_tool_name_conflicts(
3964 vec![
3965 TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
3966 TestTool::new(
3967 "tool-with-very-very-very-long-name",
3968 ToolSource::ContextServer {
3969 id: "mcp-with-very-very-very-long-name".into(),
3970 },
3971 ),
3972 ],
3973 vec![
3974 "tool-with-very-very-very-long-name",
3975 "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
3976 ],
3977 );
3978
3979 fn assert_resolve_tool_name_conflicts(
3980 tools: Vec<TestTool>,
3981 expected: Vec<impl Into<String>>,
3982 ) {
3983 let tools: Vec<Arc<dyn Tool>> = tools
3984 .into_iter()
3985 .map(|t| Arc::new(t) as Arc<dyn Tool>)
3986 .collect();
3987 let tools = resolve_tool_name_conflicts(&tools);
3988 assert_eq!(tools.len(), expected.len());
3989 for (i, expected_name) in expected.into_iter().enumerate() {
3990 let expected_name = expected_name.into();
3991 let actual_name = &tools[i].0;
3992 assert_eq!(
3993 actual_name, &expected_name,
3994 "Expected '{}' got '{}' at index {}",
3995 expected_name, actual_name, i
3996 );
3997 }
3998 }
3999
4000 struct TestTool {
4001 name: String,
4002 source: ToolSource,
4003 }
4004
4005 impl TestTool {
4006 fn new(name: impl Into<String>, source: ToolSource) -> Self {
4007 Self {
4008 name: name.into(),
4009 source,
4010 }
4011 }
4012 }
4013
4014 impl Tool for TestTool {
4015 fn name(&self) -> String {
4016 self.name.clone()
4017 }
4018
4019 fn icon(&self) -> IconName {
4020 IconName::Ai
4021 }
4022
4023 fn may_perform_edits(&self) -> bool {
4024 false
4025 }
4026
4027 fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
4028 true
4029 }
4030
4031 fn source(&self) -> ToolSource {
4032 self.source.clone()
4033 }
4034
4035 fn description(&self) -> String {
4036 "Test tool".to_string()
4037 }
4038
4039 fn ui_text(&self, _input: &serde_json::Value) -> String {
4040 "Test tool".to_string()
4041 }
4042
4043 fn run(
4044 self: Arc<Self>,
4045 _input: serde_json::Value,
4046 _request: Arc<LanguageModelRequest>,
4047 _project: Entity<Project>,
4048 _action_log: Entity<ActionLog>,
4049 _model: Arc<dyn LanguageModel>,
4050 _window: Option<AnyWindowHandle>,
4051 _cx: &mut App,
4052 ) -> assistant_tool::ToolResult {
4053 assistant_tool::ToolResult {
4054 output: Task::ready(Err(anyhow::anyhow!("No content"))),
4055 card: None,
4056 }
4057 }
4058 }
4059 }
4060
    // Helpers for simulating provider errors in the retry tests
4062 enum TestError {
4063 Overloaded,
4064 InternalServerError,
4065 }
4066
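    // A model wrapper whose completions always fail with the configured error.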
4067 struct ErrorInjector {
4068 inner: Arc<FakeLanguageModel>,
4069 error_type: TestError,
4070 }
4071
4072 impl ErrorInjector {
4073 fn new(error_type: TestError) -> Self {
4074 Self {
4075 inner: Arc::new(FakeLanguageModel::default()),
4076 error_type,
4077 }
4078 }
4079 }
4080
4081 impl LanguageModel for ErrorInjector {
4082 fn id(&self) -> LanguageModelId {
4083 self.inner.id()
4084 }
4085
4086 fn name(&self) -> LanguageModelName {
4087 self.inner.name()
4088 }
4089
4090 fn provider_id(&self) -> LanguageModelProviderId {
4091 self.inner.provider_id()
4092 }
4093
4094 fn provider_name(&self) -> LanguageModelProviderName {
4095 self.inner.provider_name()
4096 }
4097
4098 fn supports_tools(&self) -> bool {
4099 self.inner.supports_tools()
4100 }
4101
4102 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4103 self.inner.supports_tool_choice(choice)
4104 }
4105
4106 fn supports_images(&self) -> bool {
4107 self.inner.supports_images()
4108 }
4109
4110 fn telemetry_id(&self) -> String {
4111 self.inner.telemetry_id()
4112 }
4113
4114 fn max_token_count(&self) -> u64 {
4115 self.inner.max_token_count()
4116 }
4117
4118 fn count_tokens(
4119 &self,
4120 request: LanguageModelRequest,
4121 cx: &App,
4122 ) -> BoxFuture<'static, Result<u64>> {
4123 self.inner.count_tokens(request, cx)
4124 }
4125
4126 fn stream_completion(
4127 &self,
4128 _request: LanguageModelRequest,
4129 _cx: &AsyncApp,
4130 ) -> BoxFuture<
4131 'static,
4132 Result<
4133 BoxStream<
4134 'static,
4135 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4136 >,
4137 LanguageModelCompletionError,
4138 >,
4139 > {
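            // Every request yields a stream whose only item is the injected error.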
4140 let error = match self.error_type {
4141 TestError::Overloaded => LanguageModelCompletionError::Overloaded,
4142 TestError::InternalServerError => {
4143 LanguageModelCompletionError::ApiInternalServerError
4144 }
4145 };
4146 async move {
4147 let stream = futures::stream::once(async move { Err(error) });
4148 Ok(stream.boxed())
4149 }
4150 .boxed()
4151 }
4152
4153 fn as_fake(&self) -> &FakeLanguageModel {
4154 &self.inner
4155 }
4156 }
4157
4158 #[gpui::test]
4159 async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4160 init_test_settings(cx);
4161
4162 let project = create_test_project(cx, json!({})).await;
4163 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4164
4165 // Create model that returns overloaded error
4166 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4167
4168 // Insert a user message
4169 thread.update(cx, |thread, cx| {
4170 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4171 });
4172
4173 // Start completion
4174 thread.update(cx, |thread, cx| {
4175 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4176 });
4177
4178 cx.run_until_parked();
4179
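        // The overloaded error should have scheduled a retry and recorded retry state.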
4180 thread.read_with(cx, |thread, _| {
4181 assert!(thread.retry_state.is_some(), "Should have retry state");
4182 let retry_state = thread.retry_state.as_ref().unwrap();
4183 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4184 assert_eq!(
4185 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4186 "Should have default max attempts"
4187 );
4188 });
4189
4190 // Check that a retry message was added
4191 thread.read_with(cx, |thread, _| {
4192 let mut messages = thread.messages();
4193 assert!(
4194 messages.any(|msg| {
4195 msg.role == Role::System
4196 && msg.ui_only
4197 && msg.segments.iter().any(|seg| {
4198 if let MessageSegment::Text(text) = seg {
4199 text.contains("overloaded")
4200 && text
4201 .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4202 } else {
4203 false
4204 }
4205 })
4206 }),
4207 "Should have added a system retry message"
4208 );
4209 });
4210
4211 let retry_count = thread.update(cx, |thread, _| {
4212 thread
4213 .messages
4214 .iter()
4215 .filter(|m| {
4216 m.ui_only
4217 && m.segments.iter().any(|s| {
4218 if let MessageSegment::Text(text) = s {
4219 text.contains("Retrying") && text.contains("seconds")
4220 } else {
4221 false
4222 }
4223 })
4224 })
4225 .count()
4226 });
4227
4228 assert_eq!(retry_count, 1, "Should have one retry message");
4229 }
4230
4231 #[gpui::test]
4232 async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4233 init_test_settings(cx);
4234
4235 let project = create_test_project(cx, json!({})).await;
4236 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4237
4238 // Create model that returns internal server error
4239 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4240
4241 // Insert a user message
4242 thread.update(cx, |thread, cx| {
4243 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4244 });
4245
4246 // Start completion
4247 thread.update(cx, |thread, cx| {
4248 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4249 });
4250
4251 cx.run_until_parked();
4252
4253 // Check retry state on thread
4254 thread.read_with(cx, |thread, _| {
4255 assert!(thread.retry_state.is_some(), "Should have retry state");
4256 let retry_state = thread.retry_state.as_ref().unwrap();
4257 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4258 assert_eq!(
4259 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4260 "Should have correct max attempts"
4261 );
4262 });
4263
4264 // Check that a retry message was added with provider name
4265 thread.read_with(cx, |thread, _| {
4266 let mut messages = thread.messages();
4267 assert!(
4268 messages.any(|msg| {
4269 msg.role == Role::System
4270 && msg.ui_only
4271 && msg.segments.iter().any(|seg| {
4272 if let MessageSegment::Text(text) = seg {
4273 text.contains("internal")
4274 && text.contains("Fake")
4275 && text
4276 .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4277 } else {
4278 false
4279 }
4280 })
4281 }),
4282 "Should have added a system retry message with provider name"
4283 );
4284 });
4285
4286 // Count retry messages
4287 let retry_count = thread.update(cx, |thread, _| {
4288 thread
4289 .messages
4290 .iter()
4291 .filter(|m| {
4292 m.ui_only
4293 && m.segments.iter().any(|s| {
4294 if let MessageSegment::Text(text) = s {
4295 text.contains("Retrying") && text.contains("seconds")
4296 } else {
4297 false
4298 }
4299 })
4300 })
4301 .count()
4302 });
4303
4304 assert_eq!(retry_count, 1, "Should have one retry message");
4305 }
4306
4307 #[gpui::test]
4308 async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4309 init_test_settings(cx);
4310
4311 let project = create_test_project(cx, json!({})).await;
4312 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4313
4314 // Create model that returns overloaded error
4315 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4316
4317 // Insert a user message
4318 thread.update(cx, |thread, cx| {
4319 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4320 });
4321
        // Count completion requests; the initial attempt and each retry emit a NewRequest event
4324 let completion_count = Arc::new(Mutex::new(0));
4325 let completion_count_clone = completion_count.clone();
4326
4327 let _subscription = thread.update(cx, |_, cx| {
4328 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4329 if let ThreadEvent::NewRequest = event {
4330 *completion_count_clone.lock() += 1;
4331 }
4332 })
4333 });
4334
4335 // First attempt
4336 thread.update(cx, |thread, cx| {
4337 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4338 });
4339 cx.run_until_parked();
4340
4341 // Should have scheduled first retry - count retry messages
4342 let retry_count = thread.update(cx, |thread, _| {
4343 thread
4344 .messages
4345 .iter()
4346 .filter(|m| {
4347 m.ui_only
4348 && m.segments.iter().any(|s| {
4349 if let MessageSegment::Text(text) = s {
4350 text.contains("Retrying") && text.contains("seconds")
4351 } else {
4352 false
4353 }
4354 })
4355 })
4356 .count()
4357 });
4358 assert_eq!(retry_count, 1, "Should have scheduled first retry");
4359
4360 // Check retry state
4361 thread.read_with(cx, |thread, _| {
4362 assert!(thread.retry_state.is_some(), "Should have retry state");
4363 let retry_state = thread.retry_state.as_ref().unwrap();
4364 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4365 });
4366
4367 // Advance clock for first retry
4368 cx.executor()
4369 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4370 cx.run_until_parked();
4371
4372 // Should have scheduled second retry - count retry messages
4373 let retry_count = thread.update(cx, |thread, _| {
4374 thread
4375 .messages
4376 .iter()
4377 .filter(|m| {
4378 m.ui_only
4379 && m.segments.iter().any(|s| {
4380 if let MessageSegment::Text(text) = s {
4381 text.contains("Retrying") && text.contains("seconds")
4382 } else {
4383 false
4384 }
4385 })
4386 })
4387 .count()
4388 });
4389 assert_eq!(retry_count, 2, "Should have scheduled second retry");
4390
4391 // Check retry state updated
4392 thread.read_with(cx, |thread, _| {
4393 assert!(thread.retry_state.is_some(), "Should have retry state");
4394 let retry_state = thread.retry_state.as_ref().unwrap();
4395 assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
4396 assert_eq!(
4397 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4398 "Should have correct max attempts"
4399 );
4400 });
4401
4402 // Advance clock for second retry (exponential backoff)
4403 cx.executor()
4404 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
4405 cx.run_until_parked();
4406
        // Should have scheduled the third retry; count all retry messages
4409 let retry_count = thread.update(cx, |thread, _| {
4410 thread
4411 .messages
4412 .iter()
4413 .filter(|m| {
4414 m.ui_only
4415 && m.segments.iter().any(|s| {
4416 if let MessageSegment::Text(text) = s {
4417 text.contains("Retrying") && text.contains("seconds")
4418 } else {
4419 false
4420 }
4421 })
4422 })
4423 .count()
4424 });
4425 assert_eq!(
4426 retry_count, MAX_RETRY_ATTEMPTS as usize,
4427 "Should have scheduled third retry"
4428 );
4429
4430 // Check retry state updated
4431 thread.read_with(cx, |thread, _| {
4432 assert!(thread.retry_state.is_some(), "Should have retry state");
4433 let retry_state = thread.retry_state.as_ref().unwrap();
4434 assert_eq!(
4435 retry_state.attempt, MAX_RETRY_ATTEMPTS,
4436 "Should be at max retry attempt"
4437 );
4438 assert_eq!(
4439 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4440 "Should have correct max attempts"
4441 );
4442 });
4443
4444 // Advance clock for third retry (exponential backoff)
4445 cx.executor()
4446 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
4447 cx.run_until_parked();
4448
        // After the final delay elapses, no additional retries should be scheduled.
4450 let retry_count = thread.update(cx, |thread, _| {
4451 thread
4452 .messages
4453 .iter()
4454 .filter(|m| {
4455 m.ui_only
4456 && m.segments.iter().any(|s| {
4457 if let MessageSegment::Text(text) = s {
4458 text.contains("Retrying") && text.contains("seconds")
4459 } else {
4460 false
4461 }
4462 })
4463 })
4464 .count()
4465 });
4466 assert_eq!(
4467 retry_count, MAX_RETRY_ATTEMPTS as usize,
4468 "Should not exceed max retries"
4469 );
4470
4471 // Final completion count should be initial + max retries
4472 assert_eq!(
4473 *completion_count.lock(),
4474 (MAX_RETRY_ATTEMPTS + 1) as usize,
4475 "Should have made initial + max retry attempts"
4476 );
4477 }
4478
4479 #[gpui::test]
4480 async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4481 init_test_settings(cx);
4482
4483 let project = create_test_project(cx, json!({})).await;
4484 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4485
4486 // Create model that returns overloaded error
4487 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4488
4489 // Insert a user message
4490 thread.update(cx, |thread, cx| {
4491 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4492 });
4493
4494 // Track events
4495 let retries_failed = Arc::new(Mutex::new(false));
4496 let retries_failed_clone = retries_failed.clone();
4497
4498 let _subscription = thread.update(cx, |_, cx| {
4499 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4500 if let ThreadEvent::RetriesFailed { .. } = event {
4501 *retries_failed_clone.lock() = true;
4502 }
4503 })
4504 });
4505
4506 // Start initial completion
4507 thread.update(cx, |thread, cx| {
4508 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4509 });
4510 cx.run_until_parked();
4511
4512 // Advance through all retries
4513 for i in 0..MAX_RETRY_ATTEMPTS {
4514 let delay = if i == 0 {
4515 BASE_RETRY_DELAY_SECS
4516 } else {
4517 BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
4518 };
4519 cx.executor().advance_clock(Duration::from_secs(delay));
4520 cx.run_until_parked();
4521 }
4522
        // After the final retry is scheduled, wait for it to execute and fail.
        // Its delay is BASE_RETRY_DELAY_SECS * 2^(MAX_RETRY_ATTEMPTS - 1) seconds.
4525 let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
4526 cx.executor()
4527 .advance_clock(Duration::from_secs(final_delay));
4528 cx.run_until_parked();
4529
4530 let retry_count = thread.update(cx, |thread, _| {
4531 thread
4532 .messages
4533 .iter()
4534 .filter(|m| {
4535 m.ui_only
4536 && m.segments.iter().any(|s| {
4537 if let MessageSegment::Text(text) = s {
4538 text.contains("Retrying") && text.contains("seconds")
4539 } else {
4540 false
4541 }
4542 })
4543 })
4544 .count()
4545 });
4546
4547 // After max retries, should emit RetriesFailed event
4548 assert_eq!(
4549 retry_count, MAX_RETRY_ATTEMPTS as usize,
4550 "Should have attempted max retries"
4551 );
4552 assert!(
4553 *retries_failed.lock(),
4554 "Should emit RetriesFailed event after max retries exceeded"
4555 );
4556
4557 // Retry state should be cleared
4558 thread.read_with(cx, |thread, _| {
4559 assert!(
4560 thread.retry_state.is_none(),
4561 "Retry state should be cleared after max retries"
4562 );
4563
4564 // Verify we have the expected number of retry messages
4565 let retry_messages = thread
4566 .messages
4567 .iter()
4568 .filter(|msg| msg.ui_only && msg.role == Role::System)
4569 .count();
4570 assert_eq!(
4571 retry_messages, MAX_RETRY_ATTEMPTS as usize,
4572 "Should have one retry message per attempt"
4573 );
4574 });
4575 }
4576
4577 #[gpui::test]
4578 async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
4579 init_test_settings(cx);
4580
4581 let project = create_test_project(cx, json!({})).await;
4582 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4583
4584 // We'll use a wrapper to switch behavior after first failure
4585 struct RetryTestModel {
4586 inner: Arc<FakeLanguageModel>,
4587 failed_once: Arc<Mutex<bool>>,
4588 }
4589
4590 impl LanguageModel for RetryTestModel {
4591 fn id(&self) -> LanguageModelId {
4592 self.inner.id()
4593 }
4594
4595 fn name(&self) -> LanguageModelName {
4596 self.inner.name()
4597 }
4598
4599 fn provider_id(&self) -> LanguageModelProviderId {
4600 self.inner.provider_id()
4601 }
4602
4603 fn provider_name(&self) -> LanguageModelProviderName {
4604 self.inner.provider_name()
4605 }
4606
4607 fn supports_tools(&self) -> bool {
4608 self.inner.supports_tools()
4609 }
4610
4611 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4612 self.inner.supports_tool_choice(choice)
4613 }
4614
4615 fn supports_images(&self) -> bool {
4616 self.inner.supports_images()
4617 }
4618
4619 fn telemetry_id(&self) -> String {
4620 self.inner.telemetry_id()
4621 }
4622
4623 fn max_token_count(&self) -> u64 {
4624 self.inner.max_token_count()
4625 }
4626
4627 fn count_tokens(
4628 &self,
4629 request: LanguageModelRequest,
4630 cx: &App,
4631 ) -> BoxFuture<'static, Result<u64>> {
4632 self.inner.count_tokens(request, cx)
4633 }
4634
4635 fn stream_completion(
4636 &self,
4637 request: LanguageModelRequest,
4638 cx: &AsyncApp,
4639 ) -> BoxFuture<
4640 'static,
4641 Result<
4642 BoxStream<
4643 'static,
4644 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4645 >,
4646 LanguageModelCompletionError,
4647 >,
4648 > {
4649 if !*self.failed_once.lock() {
4650 *self.failed_once.lock() = true;
4651 // Return error on first attempt
4652 let stream = futures::stream::once(async move {
4653 Err(LanguageModelCompletionError::Overloaded)
4654 });
4655 async move { Ok(stream.boxed()) }.boxed()
4656 } else {
4657 // Succeed on retry
4658 self.inner.stream_completion(request, cx)
4659 }
4660 }
4661
4662 fn as_fake(&self) -> &FakeLanguageModel {
4663 &self.inner
4664 }
4665 }
4666
4667 let model = Arc::new(RetryTestModel {
4668 inner: Arc::new(FakeLanguageModel::default()),
4669 failed_once: Arc::new(Mutex::new(false)),
4670 });
4671
4672 // Insert a user message
4673 thread.update(cx, |thread, cx| {
4674 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4675 });
4676
        // Track when the retry completes successfully
4679 let retry_completed = Arc::new(Mutex::new(false));
4680 let retry_completed_clone = retry_completed.clone();
4681
4682 let _subscription = thread.update(cx, |_, cx| {
4683 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4684 if let ThreadEvent::StreamedCompletion = event {
4685 *retry_completed_clone.lock() = true;
4686 }
4687 })
4688 });
4689
4690 // Start completion
4691 thread.update(cx, |thread, cx| {
4692 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4693 });
4694 cx.run_until_parked();
4695
4696 // Get the retry message ID
4697 let retry_message_id = thread.read_with(cx, |thread, _| {
4698 thread
4699 .messages()
4700 .find(|msg| msg.role == Role::System && msg.ui_only)
4701 .map(|msg| msg.id)
4702 .expect("Should have a retry message")
4703 });
4704
4705 // Wait for retry
4706 cx.executor()
4707 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4708 cx.run_until_parked();
4709
4710 // Stream some successful content
4711 let fake_model = model.as_fake();
4712 // After the retry, there should be a new pending completion
4713 let pending = fake_model.pending_completions();
4714 assert!(
4715 !pending.is_empty(),
4716 "Should have a pending completion after retry"
4717 );
4718 fake_model.stream_completion_response(&pending[0], "Success!");
4719 fake_model.end_completion_stream(&pending[0]);
4720 cx.run_until_parked();
4721
4722 // Check that the retry completed successfully
4723 assert!(
4724 *retry_completed.lock(),
4725 "Retry should have completed successfully"
4726 );
4727
4728 // Retry message should still exist but be marked as ui_only
4729 thread.read_with(cx, |thread, _| {
4730 let retry_msg = thread
4731 .message(retry_message_id)
4732 .expect("Retry message should still exist");
4733 assert!(retry_msg.ui_only, "Retry message should be ui_only");
4734 assert_eq!(
4735 retry_msg.role,
4736 Role::System,
4737 "Retry message should have System role"
4738 );
4739 });
4740 }
4741
4742 #[gpui::test]
4743 async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4744 init_test_settings(cx);
4745
4746 let project = create_test_project(cx, json!({})).await;
4747 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4748
4749 // Create a model that fails once then succeeds
4750 struct FailOnceModel {
4751 inner: Arc<FakeLanguageModel>,
4752 failed_once: Arc<Mutex<bool>>,
4753 }
4754
4755 impl LanguageModel for FailOnceModel {
4756 fn id(&self) -> LanguageModelId {
4757 self.inner.id()
4758 }
4759
4760 fn name(&self) -> LanguageModelName {
4761 self.inner.name()
4762 }
4763
4764 fn provider_id(&self) -> LanguageModelProviderId {
4765 self.inner.provider_id()
4766 }
4767
4768 fn provider_name(&self) -> LanguageModelProviderName {
4769 self.inner.provider_name()
4770 }
4771
4772 fn supports_tools(&self) -> bool {
4773 self.inner.supports_tools()
4774 }
4775
4776 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4777 self.inner.supports_tool_choice(choice)
4778 }
4779
4780 fn supports_images(&self) -> bool {
4781 self.inner.supports_images()
4782 }
4783
4784 fn telemetry_id(&self) -> String {
4785 self.inner.telemetry_id()
4786 }
4787
4788 fn max_token_count(&self) -> u64 {
4789 self.inner.max_token_count()
4790 }
4791
4792 fn count_tokens(
4793 &self,
4794 request: LanguageModelRequest,
4795 cx: &App,
4796 ) -> BoxFuture<'static, Result<u64>> {
4797 self.inner.count_tokens(request, cx)
4798 }
4799
4800 fn stream_completion(
4801 &self,
4802 request: LanguageModelRequest,
4803 cx: &AsyncApp,
4804 ) -> BoxFuture<
4805 'static,
4806 Result<
4807 BoxStream<
4808 'static,
4809 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4810 >,
4811 LanguageModelCompletionError,
4812 >,
4813 > {
4814 if !*self.failed_once.lock() {
4815 *self.failed_once.lock() = true;
4816 // Return error on first attempt
4817 let stream = futures::stream::once(async move {
4818 Err(LanguageModelCompletionError::Overloaded)
4819 });
4820 async move { Ok(stream.boxed()) }.boxed()
4821 } else {
4822 // Succeed on retry
4823 self.inner.stream_completion(request, cx)
4824 }
4825 }
4826 }
4827
4828 let fail_once_model = Arc::new(FailOnceModel {
4829 inner: Arc::new(FakeLanguageModel::default()),
4830 failed_once: Arc::new(Mutex::new(false)),
4831 });
4832
4833 // Insert a user message
4834 thread.update(cx, |thread, cx| {
4835 thread.insert_user_message(
4836 "Test message",
4837 ContextLoadResult::default(),
4838 None,
4839 vec![],
4840 cx,
4841 );
4842 });
4843
4844 // Start completion with fail-once model
4845 thread.update(cx, |thread, cx| {
4846 thread.send_to_model(
4847 fail_once_model.clone(),
4848 CompletionIntent::UserPrompt,
4849 None,
4850 cx,
4851 );
4852 });
4853
4854 cx.run_until_parked();
4855
4856 // Verify retry state exists after first failure
4857 thread.read_with(cx, |thread, _| {
4858 assert!(
4859 thread.retry_state.is_some(),
4860 "Should have retry state after failure"
4861 );
4862 });
4863
4864 // Wait for retry delay
4865 cx.executor()
4866 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4867 cx.run_until_parked();
4868
        // The retry now hits the FailOnceModel's success path; drive the inner
        // FakeLanguageModel to complete the stream.
4871 let inner_fake = fail_once_model.inner.clone();
4872
4873 // Wait a bit for the retry to start
4874 cx.run_until_parked();
4875
4876 // Check for pending completions and complete them
4877 if let Some(pending) = inner_fake.pending_completions().first() {
4878 inner_fake.stream_completion_response(pending, "Success!");
4879 inner_fake.end_completion_stream(pending);
4880 }
4881 cx.run_until_parked();
4882
4883 thread.read_with(cx, |thread, _| {
4884 assert!(
4885 thread.retry_state.is_none(),
4886 "Retry state should be cleared after successful completion"
4887 );
4888
4889 let has_assistant_message = thread
4890 .messages
4891 .iter()
4892 .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4893 assert!(
4894 has_assistant_message,
4895 "Should have an assistant message after successful retry"
4896 );
4897 });
4898 }
4899
4900 #[gpui::test]
4901 async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4902 init_test_settings(cx);
4903
4904 let project = create_test_project(cx, json!({})).await;
4905 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4906
4907 // Create a model that returns rate limit error with retry_after
4908 struct RateLimitModel {
4909 inner: Arc<FakeLanguageModel>,
4910 }
4911
4912 impl LanguageModel for RateLimitModel {
4913 fn id(&self) -> LanguageModelId {
4914 self.inner.id()
4915 }
4916
4917 fn name(&self) -> LanguageModelName {
4918 self.inner.name()
4919 }
4920
4921 fn provider_id(&self) -> LanguageModelProviderId {
4922 self.inner.provider_id()
4923 }
4924
4925 fn provider_name(&self) -> LanguageModelProviderName {
4926 self.inner.provider_name()
4927 }
4928
4929 fn supports_tools(&self) -> bool {
4930 self.inner.supports_tools()
4931 }
4932
4933 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4934 self.inner.supports_tool_choice(choice)
4935 }
4936
4937 fn supports_images(&self) -> bool {
4938 self.inner.supports_images()
4939 }
4940
4941 fn telemetry_id(&self) -> String {
4942 self.inner.telemetry_id()
4943 }
4944
4945 fn max_token_count(&self) -> u64 {
4946 self.inner.max_token_count()
4947 }
4948
4949 fn count_tokens(
4950 &self,
4951 request: LanguageModelRequest,
4952 cx: &App,
4953 ) -> BoxFuture<'static, Result<u64>> {
4954 self.inner.count_tokens(request, cx)
4955 }
4956
4957 fn stream_completion(
4958 &self,
4959 _request: LanguageModelRequest,
4960 _cx: &AsyncApp,
4961 ) -> BoxFuture<
4962 'static,
4963 Result<
4964 BoxStream<
4965 'static,
4966 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4967 >,
4968 LanguageModelCompletionError,
4969 >,
4970 > {
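            // Always fail with a rate-limit error carrying an explicit retry_after.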
4971 async move {
4972 let stream = futures::stream::once(async move {
4973 Err(LanguageModelCompletionError::RateLimitExceeded {
4974 retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
4975 })
4976 });
4977 Ok(stream.boxed())
4978 }
4979 .boxed()
4980 }
4981
4982 fn as_fake(&self) -> &FakeLanguageModel {
4983 &self.inner
4984 }
4985 }
4986
4987 let model = Arc::new(RateLimitModel {
4988 inner: Arc::new(FakeLanguageModel::default()),
4989 });
4990
4991 // Insert a user message
4992 thread.update(cx, |thread, cx| {
4993 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4994 });
4995
4996 // Start completion
4997 thread.update(cx, |thread, cx| {
4998 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4999 });
5000
5001 cx.run_until_parked();
5002
5003 let retry_count = thread.update(cx, |thread, _| {
5004 thread
5005 .messages
5006 .iter()
5007 .filter(|m| {
5008 m.ui_only
5009 && m.segments.iter().any(|s| {
5010 if let MessageSegment::Text(text) = s {
5011 text.contains("rate limit exceeded")
5012 } else {
5013 false
5014 }
5015 })
5016 })
5017 .count()
5018 });
5019 assert_eq!(retry_count, 1, "Should have scheduled one retry");
5020
5021 thread.read_with(cx, |thread, _| {
5022 assert!(
5023 thread.retry_state.is_none(),
5024 "Rate limit errors should not set retry_state"
5025 );
5026 });
5049
        // The rate-limit retry message should state the retry delay but not an attempt count
5051 thread.read_with(cx, |thread, _| {
5052 let retry_message = thread
5053 .messages
5054 .iter()
5055 .find(|msg| msg.role == Role::System && msg.ui_only)
5056 .expect("Should have a retry message");
5057
5059 if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5060 assert!(
5061 !text.contains("attempt"),
5062 "Rate limit retry message should not contain attempt count"
5063 );
5064 assert!(
5065 text.contains(&format!(
5066 "Retrying in {} seconds",
5067 TEST_RATE_LIMIT_RETRY_SECS
5068 )),
5069 "Rate limit retry message should contain retry delay"
5070 );
5071 }
5072 });
5073 }
5074
5075 #[gpui::test]
5076 async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5077 init_test_settings(cx);
5078
5079 let project = create_test_project(cx, json!({})).await;
5080 let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5081
5082 // Insert a regular user message
5083 thread.update(cx, |thread, cx| {
5084 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5085 });
5086
5087 // Insert a UI-only message (like our retry notifications)
5088 thread.update(cx, |thread, cx| {
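            // Build the message by hand so the ui_only flag can be set directly.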
5089 let id = thread.next_message_id.post_inc();
5090 thread.messages.push(Message {
5091 id,
5092 role: Role::System,
5093 segments: vec![MessageSegment::Text(
5094 "This is a UI-only message that should not be sent to the model".to_string(),
5095 )],
5096 loaded_context: LoadedContext::default(),
5097 creases: Vec::new(),
5098 is_hidden: true,
5099 ui_only: true,
5100 });
5101 cx.emit(ThreadEvent::MessageAdded(id));
5102 });
5103
5104 // Insert another regular message
5105 thread.update(cx, |thread, cx| {
5106 thread.insert_user_message(
5107 "How are you?",
5108 ContextLoadResult::default(),
5109 None,
5110 vec![],
5111 cx,
5112 );
5113 });
5114
5115 // Generate the completion request
5116 let request = thread.update(cx, |thread, cx| {
5117 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5118 });
5119
5120 // Verify that the request only contains non-UI-only messages
5121 // Should have system prompt + 2 user messages, but not the UI-only message
5122 let user_messages: Vec<_> = request
5123 .messages
5124 .iter()
5125 .filter(|msg| msg.role == Role::User)
5126 .collect();
5127 assert_eq!(
5128 user_messages.len(),
5129 2,
5130 "Should have exactly 2 user messages"
5131 );
5132
5133 // Verify the UI-only content is not present anywhere in the request
5134 let request_text = request
5135 .messages
5136 .iter()
5137 .flat_map(|msg| &msg.content)
5138 .filter_map(|content| match content {
5139 MessageContent::Text(text) => Some(text.as_str()),
5140 _ => None,
5141 })
5142 .collect::<String>();
5143
5144 assert!(
5145 !request_text.contains("UI-only message"),
5146 "UI-only message content should not be in the request"
5147 );
5148
5149 // Verify the thread still has all 3 messages (including UI-only)
5150 thread.read_with(cx, |thread, _| {
5151 assert_eq!(
5152 thread.messages().count(),
5153 3,
5154 "Thread should have 3 messages"
5155 );
5156 assert_eq!(
5157 thread.messages().filter(|m| m.ui_only).count(),
5158 1,
5159 "Thread should have 1 UI-only message"
5160 );
5161 });
5162
5163 // Verify that UI-only messages are not serialized
5164 let serialized = thread
5165 .update(cx, |thread, cx| thread.serialize(cx))
5166 .await
5167 .unwrap();
5168 assert_eq!(
5169 serialized.messages.len(),
5170 2,
5171 "Serialized thread should only have 2 messages (no UI-only)"
5172 );
5173 }
5174
5175 #[gpui::test]
5176 async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5177 init_test_settings(cx);
5178
5179 let project = create_test_project(cx, json!({})).await;
5180 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5181
5182 // Create model that returns overloaded error
5183 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5184
5185 // Insert a user message
5186 thread.update(cx, |thread, cx| {
5187 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5188 });
5189
5190 // Start completion
5191 thread.update(cx, |thread, cx| {
5192 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5193 });
5194
5195 cx.run_until_parked();
5196
5197 // Verify retry was scheduled by checking for retry message
5198 let has_retry_message = thread.read_with(cx, |thread, _| {
5199 thread.messages.iter().any(|m| {
5200 m.ui_only
5201 && m.segments.iter().any(|s| {
5202 if let MessageSegment::Text(text) = s {
5203 text.contains("Retrying") && text.contains("seconds")
5204 } else {
5205 false
5206 }
5207 })
5208 })
5209 });
5210 assert!(has_retry_message, "Should have scheduled a retry");
5211
5212 // Cancel the completion before the retry happens
5213 thread.update(cx, |thread, cx| {
5214 thread.cancel_last_completion(None, cx);
5215 });
5216
5217 cx.run_until_parked();
5218
5219 // The retry should not have happened - no pending completions
5220 let fake_model = model.as_fake();
5221 assert_eq!(
5222 fake_model.pending_completions().len(),
5223 0,
5224 "Should have no pending completions after cancellation"
5225 );
5226
5227 // Verify the retry was cancelled by checking retry state
5228 thread.read_with(cx, |thread, _| {
5229 if let Some(retry_state) = &thread.retry_state {
5230 panic!(
5231 "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5232 retry_state.attempt, retry_state.max_attempts, retry_state.intent
5233 );
5234 }
5235 });
5236 }
5237
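    // Drives the thread into ThreadSummary::Error by sending a message and then
    // letting the summary stream end without producing any text.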
5238 fn test_summarize_error(
5239 model: &Arc<dyn LanguageModel>,
5240 thread: &Entity<Thread>,
5241 cx: &mut TestAppContext,
5242 ) {
5243 thread.update(cx, |thread, cx| {
5244 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5245 thread.send_to_model(
5246 model.clone(),
5247 CompletionIntent::ThreadSummarization,
5248 None,
5249 cx,
5250 );
5251 });
5252
5253 let fake_model = model.as_fake();
5254 simulate_successful_response(&fake_model, cx);
5255
5256 thread.read_with(cx, |thread, _| {
5257 assert!(matches!(thread.summary(), ThreadSummary::Generating));
5258 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5259 });
5260
5261 // Simulate summary request ending
5262 cx.run_until_parked();
5263 fake_model.end_last_completion_stream();
5264 cx.run_until_parked();
5265
5266 // State is set to Error and default message
5267 thread.read_with(cx, |thread, _| {
5268 assert!(matches!(thread.summary(), ThreadSummary::Error));
5269 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5270 });
5271 }
5272
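    // Streams a canned assistant reply for the most recent completion request and closes the stream.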
5273 fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
5274 cx.run_until_parked();
5275 fake_model.stream_last_completion_response("Assistant response");
5276 fake_model.end_last_completion_stream();
5277 cx.run_until_parked();
5278 }
5279
5280 fn init_test_settings(cx: &mut TestAppContext) {
5281 cx.update(|cx| {
5282 let settings_store = SettingsStore::test(cx);
5283 cx.set_global(settings_store);
5284 language::init(cx);
5285 Project::init_settings(cx);
5286 AgentSettings::register(cx);
5287 prompt_store::init(cx);
5288 thread_store::init(cx);
5289 workspace::init_settings(cx);
5290 language_model::init_settings(cx);
5291 ThemeSettings::register(cx);
5292 ToolRegistry::default_global(cx);
5293 });
5294 }
5295
5296 // Helper to create a test project with test files
5297 async fn create_test_project(
5298 cx: &mut TestAppContext,
5299 files: serde_json::Value,
5300 ) -> Entity<Project> {
5301 let fs = FakeFs::new(cx.executor());
5302 fs.insert_tree(path!("/test"), files).await;
5303 Project::test(fs, [path!("/test").as_ref()], cx).await
5304 }
5305
5306 async fn setup_test_environment(
5307 cx: &mut TestAppContext,
5308 project: Entity<Project>,
5309 ) -> (
5310 Entity<Workspace>,
5311 Entity<ThreadStore>,
5312 Entity<Thread>,
5313 Entity<ContextStore>,
5314 Arc<dyn LanguageModel>,
5315 ) {
5316 let (workspace, cx) =
5317 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5318
5319 let thread_store = cx
5320 .update(|_, cx| {
5321 ThreadStore::load(
5322 project.clone(),
5323 cx.new(|_| ToolWorkingSet::default()),
5324 None,
5325 Arc::new(PromptBuilder::new(None).unwrap()),
5326 cx,
5327 )
5328 })
5329 .await
5330 .unwrap();
5331
5332 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5333 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5334
5335 let provider = Arc::new(FakeLanguageModelProvider);
5336 let model = provider.test_model();
5337 let model: Arc<dyn LanguageModel> = Arc::new(model);
5338
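        // Register the fake model as both the default and the thread-summary model.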
5339 cx.update(|_, cx| {
5340 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5341 registry.set_default_model(
5342 Some(ConfiguredModel {
5343 provider: provider.clone(),
5344 model: model.clone(),
5345 }),
5346 cx,
5347 );
5348 registry.set_thread_summary_model(
5349 Some(ConfiguredModel {
5350 provider,
5351 model: model.clone(),
5352 }),
5353 cx,
5354 );
5355 })
5356 });
5357
5358 (workspace, thread_store, thread, context_store, model)
5359 }
5360
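    // Opens `path` in the project and adds the resulting buffer to the context store.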
5361 async fn add_file_to_context(
5362 project: &Entity<Project>,
5363 context_store: &Entity<ContextStore>,
5364 path: &str,
5365 cx: &mut TestAppContext,
5366 ) -> Result<Entity<language::Buffer>> {
5367 let buffer_path = project
5368 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5369 .unwrap();
5370
5371 let buffer = project
5372 .update(cx, |project, cx| {
5373 project.open_buffer(buffer_path.clone(), cx)
5374 })
5375 .await
5376 .unwrap();
5377
5378 context_store.update(cx, |context_store, cx| {
5379 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5380 });
5381
5382 Ok(buffer)
5383 }
5384}