1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9 tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
10};
11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
12use anyhow::{Result, anyhow};
13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
14use chrono::{DateTime, Utc};
15use client::{ModelRequestUsage, RequestUsage};
16use collections::{HashMap, HashSet};
17use feature_flags::{self, FeatureFlagAppExt};
18use futures::{FutureExt, StreamExt as _, future::Shared};
19use git::repository::DiffType;
20use gpui::{
21 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
22 WeakEntity, Window,
23};
24use language_model::{
25 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
27 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
28 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
29 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
30 TokenUsage,
31};
32use postage::stream::Stream as _;
33use project::{
34 Project,
35 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
36};
37use prompt_store::{ModelContext, PromptBuilder};
38use proto::Plan;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use settings::Settings;
42use std::{
43 io::Write,
44 ops::Range,
45 sync::Arc,
46 time::{Duration, Instant},
47};
48use thiserror::Error;
49use util::{ResultExt as _, post_inc};
50use uuid::Uuid;
51use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
52
/// Maximum number of automatic retries for a failed completion request.
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay (seconds) between retries.
/// NOTE(review): presumably scaled into a backoff by the retry logic — confirm at the call site.
const BASE_RETRY_DELAY_SECS: u64 = 5;
55
/// Globally unique identifier for a [`Thread`], backed by a UUID v4 string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Creates a fresh, random thread ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    /// Rebuilds an ID from its string form (e.g. when loading a stored thread).
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
78
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a fresh, random prompt ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
96
/// Monotonically increasing, thread-local identifier for a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    /// The raw numeric value of this ID.
    pub fn as_usize(&self) -> usize {
        self.0
    }
}
109
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
    /// Hidden messages are excluded from turn-boundary checks (see `is_turn_end`).
    pub is_hidden: bool,
    /// UI-only messages are never persisted (see `serialize`).
    pub ui_only: bool,
}
131
132impl Message {
133 /// Returns whether the message contains any meaningful text that should be displayed
134 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
135 pub fn should_display_content(&self) -> bool {
136 self.segments.iter().all(|segment| segment.should_display())
137 }
138
139 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
140 if let Some(MessageSegment::Thinking {
141 text: segment,
142 signature: current_signature,
143 }) = self.segments.last_mut()
144 {
145 if let Some(signature) = signature {
146 *current_signature = Some(signature);
147 }
148 segment.push_str(text);
149 } else {
150 self.segments.push(MessageSegment::Thinking {
151 text: text.to_string(),
152 signature,
153 });
154 }
155 }
156
157 pub fn push_redacted_thinking(&mut self, data: String) {
158 self.segments.push(MessageSegment::RedactedThinking(data));
159 }
160
161 pub fn push_text(&mut self, text: &str) {
162 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
163 segment.push_str(text);
164 } else {
165 self.segments.push(MessageSegment::Text(text.to_string()));
166 }
167 }
168
169 pub fn to_string(&self) -> String {
170 let mut result = String::new();
171
172 if !self.loaded_context.text.is_empty() {
173 result.push_str(&self.loaded_context.text);
174 }
175
176 for segment in &self.segments {
177 match segment {
178 MessageSegment::Text(text) => result.push_str(text),
179 MessageSegment::Thinking { text, .. } => {
180 result.push_str("<think>\n");
181 result.push_str(text);
182 result.push_str("\n</think>");
183 }
184 MessageSegment::RedactedThinking(_) => {}
185 }
186 }
187
188 result
189 }
190}
191
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(String),
}

impl MessageSegment {
    /// Returns whether this segment carries meaningful, user-visible content.
    ///
    /// Fix: this previously returned `text.is_empty()`, which inverted the
    /// documented intent of `Message::should_display_content` ("contains any
    /// meaningful text that should be displayed") — an empty segment reported
    /// itself as displayable and a non-empty one did not.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking is opaque data with nothing to display.
            Self::RedactedThinking(_) => false,
        }
    }

    /// The plain text of a `Text` segment; `None` for thinking segments.
    pub fn text(&self) -> Option<&str> {
        match self {
            MessageSegment::Text(text) => Some(text),
            _ => None,
        }
    }
}
218
/// Snapshot of the project's worktrees, serialized alongside the thread
/// (see `initial_project_snapshot` in `Thread`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// Associates a git checkpoint with the user message that produced it.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up / thumbs-down feedback for a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
251
/// Outcome of the most recent checkpoint restore (see `restore_checkpoint`):
/// still in flight, or failed with an error message. Cleared on success.
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
261
262impl LastRestoreCheckpoint {
263 pub fn message_id(&self) -> MessageId {
264 match self {
265 LastRestoreCheckpoint::Pending { message_id } => *message_id,
266 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
267 }
268 }
269}
270
/// Lifecycle of the long-form ("detailed") thread summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// Generation in progress, covering messages up to `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// Finished summary text, valid for messages up to `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
283
284impl DetailedSummaryState {
285 fn text(&self) -> Option<SharedString> {
286 if let Self::Generated { text, .. } = self {
287 Some(text.clone())
288 } else {
289 None
290 }
291 }
292}
293
/// Aggregated token consumption for a thread, paired with the model's
/// context-window size (`max` is 0 when no model is selected).
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: u64,
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies current usage against the context window.
    pub fn ratio(&self) -> TokenUsageRatio {
        // Debug builds allow overriding the warning threshold via env var.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or("0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            return TokenUsageRatio::Normal;
        }
        if self.total >= self.max {
            return TokenUsageRatio::Exceeded;
        }
        if self.total as f32 / self.max as f32 >= warning_threshold {
            return TokenUsageRatio::Warning;
        }
        TokenUsageRatio::Normal
    }

    /// Returns a copy of this usage with `tokens` more counted against it.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            max: self.max,
            total: self.total + tokens,
        }
    }
}

#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
338
/// Where a pending completion currently stands (sending, queued, started).
/// NOTE(review): variant semantics inferred from names — confirm against the
/// completion request status handling.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}

/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Timestamp of the last mutation (see `touch_updated_at`).
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    /// Background task generating the short summary, if any.
    pending_summary: Task<Option<()>>,
    /// Background task generating the detailed summary, if any.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// Messages ordered by id; `message()` relies on this ordering for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    /// Shared system-prompt context for the project.
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the user message that created them.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured for the in-flight user message; finalized or
    /// discarded once generation settles.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured at creation; shared so serialization can await it.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage recorded per request.
    request_token_usage: Vec<TokenUsage>,
    /// Token usage accumulated over the thread's lifetime.
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    /// Automatic-retry bookkeeping for failed completions.
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// When the most recent streaming chunk arrived; used for staleness checks.
    last_received_chunk_at: Option<Instant>,
    /// Debug/test hook invoked with each request and its raw events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Turns the agent may still take; `u32::MAX` means unlimited.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}

/// Bookkeeping for the automatic retry of a failed completion request.
#[derive(Clone, Debug)]
struct RetryState {
    attempt: u8,
    max_attempts: u8,
    intent: CompletionIntent,
}

/// Lifecycle of the short thread title shown in the UI.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    Pending,
    Generating,
    Ready(SharedString),
    Error,
}
402
403impl ThreadSummary {
404 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
405
406 pub fn or_default(&self) -> SharedString {
407 self.unwrap_or(Self::DEFAULT)
408 }
409
410 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
411 self.ready().unwrap_or_else(|| message.into())
412 }
413
414 pub fn ready(&self) -> Option<SharedString> {
415 match self {
416 ThreadSummary::Ready(summary) => Some(summary.clone()),
417 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
418 }
419 }
420}
421
/// Recorded when a request exceeded the model's context window, so callers can
/// tell the overflow applies to this specific model and token count.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
429
impl Thread {
    /// Creates an empty thread, taking the completion mode and default profile
    /// from global settings and the model from the registry's default.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project snapshot eagerly; `shared()` lets later
            // serialization await the same result.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
486
    /// Reconstructs a thread from its serialized form.
    ///
    /// `window` is `None` in headless mode. Live state that is not persisted
    /// comes back empty: crease context handles are `None`, loaded-context
    /// entries and images are empty, and no message is `ui_only`.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was using; fall back to the registry default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        // Settings fill in anything the serialized thread predates.
        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
610
    /// Installs a debug/test hook that observes each request and its raw events.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches the agent profile, emitting `ProfileChanged` only on an actual change.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified right now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to the registry
    /// default when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Sets the summary text. Ignored while a summary is pending/generating;
    /// an empty string is normalized to the default title. Emits
    /// `SummaryChanged` only when the text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
700
    /// Looks up a message by ID. Relies on `messages` being sorted by ID
    /// (IDs are assigned monotonically in `insert_message`).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// Returns `None` when no chunk has been received yet
    /// (`last_received_chunk_at` is unset).
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived (resets staleness).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
762
    /// Restores the project to `checkpoint` via the git store. Marks the
    /// restore as pending up front; on success, truncates the thread back to
    /// the checkpoint's message and clears the marker, otherwise records the
    /// error. The pending checkpoint is dropped either way.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
798
799 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
800 let pending_checkpoint = if self.is_generating() {
801 return;
802 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
803 checkpoint
804 } else {
805 return;
806 };
807
808 self.finalize_checkpoint(pending_checkpoint, cx);
809 }
810
    /// Compares `pending_checkpoint` against a freshly taken checkpoint and
    /// keeps it only when the repository actually changed; if taking or
    /// comparing checkpoints fails, the checkpoint is kept conservatively.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only record the checkpoint when something changed since it
                // was taken.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    /// Stores a checkpoint under its message ID and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
856
    /// Removes the message with `message_id` and everything after it,
    /// dropping the checkpoints attached to the removed messages.
    /// A no-op when the ID is not present.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Iterates the contexts attached to the given message (empty when the
    /// message does not exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
878
879 pub fn is_turn_end(&self, ix: usize) -> bool {
880 if self.messages.is_empty() {
881 return false;
882 }
883
884 if !self.is_generating() && ix == self.messages.len() - 1 {
885 return true;
886 }
887
888 let Some(message) = self.messages.get(ix) else {
889 return false;
890 };
891
892 if message.role != Role::Assistant {
893 return false;
894 }
895
896 self.messages
897 .get(ix + 1)
898 .and_then(|message| {
899 self.message(message.id)
900 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
901 })
902 .unwrap_or(false)
903 }
904
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results attached to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a finished tool use; `None` for image results.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
956
    /// Return tools that are both enabled and supported by the model
    ///
    /// Empty when the model does not support tool calling. Tools whose input
    /// schema cannot be produced for the model's input format are skipped.
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            // `resolve_tool_name_conflicts` deduplicates tool names across
            // sources — NOTE(review): defined elsewhere in this file; confirm.
            resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
                .into_iter()
                .filter_map(|(name, tool)| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name,
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            Vec::default()
        }
    }
980
    /// Appends a user message. Buffers referenced by the loaded context are
    /// marked as read in the action log, and `git_checkpoint` (when provided)
    /// becomes the pending checkpoint for this message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        // Held as pending until generation settles (see
        // `finalize_pending_checkpoint`).
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Appends a hidden "continue" user message (used to resume the agent)
    /// and discards any pending checkpoint.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    /// Appends an assistant message with the given segments and no context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1046
    /// Appends a message with the next sequential ID, bumps `updated_at`,
    /// and emits `MessageAdded`. Returns the new message's ID.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
            ui_only: false,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1070
    /// Rewrites an existing message in place: role, segments, and creases are
    /// always replaced; context only when provided. A supplied checkpoint is
    /// recorded for the message. Emits `MessageEdited`; returns `false` when
    /// the ID is not found.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1103
1104 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1105 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1106 return false;
1107 };
1108 self.messages.remove(index);
1109 self.touch_updated_at();
1110 cx.emit(ThreadEvent::MessageDeleted(id));
1111 true
1112 }
1113
1114 /// Returns the representation of this [`Thread`] in a textual form.
1115 ///
1116 /// This is the representation we use when attaching a thread as context to another thread.
1117 pub fn text(&self) -> String {
1118 let mut text = String::new();
1119
1120 for message in &self.messages {
1121 text.push_str(match message.role {
1122 language_model::Role::User => "User:",
1123 language_model::Role::Assistant => "Agent:",
1124 language_model::Role::System => "System:",
1125 });
1126 text.push('\n');
1127
1128 for segment in &message.segments {
1129 match segment {
1130 MessageSegment::Text(content) => text.push_str(content),
1131 MessageSegment::Thinking { text: content, .. } => {
1132 text.push_str(&format!("<think>{}</think>", content))
1133 }
1134 MessageSegment::RedactedThinking(_) => {}
1135 }
1136 }
1137 text.push('\n');
1138 }
1139
1140 text
1141 }
1142
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot to resolve, then captures the
    /// thread's messages (excluding UI-only ones), their tool uses/results,
    /// token usage, summary state, and the configured model/profile into a
    /// [`SerializedThread`].
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // The snapshot may still be in flight; await it before reading
            // the rest of the thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    // UI-only messages (e.g. retry notices) are ephemeral and
                    // are never persisted.
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1229
    /// Returns how many turns this thread may still send to the model.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1233
    /// Sets how many turns this thread may still send to the model.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1237
1238 pub fn send_to_model(
1239 &mut self,
1240 model: Arc<dyn LanguageModel>,
1241 intent: CompletionIntent,
1242 window: Option<AnyWindowHandle>,
1243 cx: &mut Context<Self>,
1244 ) {
1245 if self.remaining_turns == 0 {
1246 return;
1247 }
1248
1249 self.remaining_turns -= 1;
1250
1251 let request = self.to_completion_request(model.clone(), intent, cx);
1252
1253 self.stream_completion(request, model, intent, window, cx);
1254 }
1255
1256 pub fn used_tools_since_last_user_message(&self) -> bool {
1257 for message in self.messages.iter().rev() {
1258 if self.tool_use.message_has_tool_results(message.id) {
1259 return true;
1260 } else if message.role == Role::User {
1261 return false;
1262 }
1263 }
1264
1265 false
1266 }
1267
    /// Builds the [`LanguageModelRequest`] for this thread's next completion.
    ///
    /// The request carries the generated system prompt (when the shared
    /// project context is ready), every non-UI-only message together with its
    /// loaded context, completed tool uses paired with their results, the set
    /// of available tools, and the completion mode supported by `model`. The
    /// most recent message with no pending tool uses is flagged as the prompt
    /// cache breakpoint.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt depends on the shared project context; if it is
        // not ready yet, surface an error instead of silently sending a
        // prompt-less request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message that is safe to cache, i.e. one
        // with no still-pending tool uses.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool results are sent back as a synthetic user
            // message that immediately follows the assistant's tool uses.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn mode is only requested when the model supports it; otherwise
        // always fall back to normal mode.
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1431
1432 fn to_summarize_request(
1433 &self,
1434 model: &Arc<dyn LanguageModel>,
1435 intent: CompletionIntent,
1436 added_user_message: String,
1437 cx: &App,
1438 ) -> LanguageModelRequest {
1439 let mut request = LanguageModelRequest {
1440 thread_id: None,
1441 prompt_id: None,
1442 intent: Some(intent),
1443 mode: None,
1444 messages: vec![],
1445 tools: Vec::new(),
1446 tool_choice: None,
1447 stop: Vec::new(),
1448 temperature: AgentSettings::temperature_for_model(model, cx),
1449 };
1450
1451 for message in &self.messages {
1452 let mut request_message = LanguageModelRequestMessage {
1453 role: message.role,
1454 content: Vec::new(),
1455 cache: false,
1456 };
1457
1458 for segment in &message.segments {
1459 match segment {
1460 MessageSegment::Text(text) => request_message
1461 .content
1462 .push(MessageContent::Text(text.clone())),
1463 MessageSegment::Thinking { .. } => {}
1464 MessageSegment::RedactedThinking(_) => {}
1465 }
1466 }
1467
1468 if request_message.content.is_empty() {
1469 continue;
1470 }
1471
1472 request.messages.push(request_message);
1473 }
1474
1475 request.messages.push(LanguageModelRequestMessage {
1476 role: Role::User,
1477 content: vec![MessageContent::Text(added_user_message)],
1478 cache: false,
1479 });
1480
1481 request
1482 }
1483
    /// Starts streaming a completion for `request` from `model`, registering
    /// it as a [`PendingCompletion`] on this thread.
    ///
    /// The spawned task drains the model's event stream, incrementally
    /// building assistant messages (text, thinking, redacted thinking),
    /// recording tool uses, and accumulating token usage. Provider errors are
    /// mapped onto [`LanguageModelKnownError`]s so the completion handler can
    /// decide whether to retry. Once the stream ends, the stop reason drives
    /// the follow-up: running pending tools, clearing the agent location, or
    /// unwinding a refused turn; transient failures may schedule a retry.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only capture request/response data when a callback is registered.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        // Timestamp of the most recent chunk; cleared when the stream ends.
        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            // Drains the event stream and resolves to the final stop reason.
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        // Translate provider-level errors into known errors so
                        // the result handler below can choose a recovery path.
                        let event = match event {
                            Ok(event) => event,
                            Err(error) => {
                                match error {
                                    LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
                                        anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
                                    }
                                    LanguageModelCompletionError::Overloaded => {
                                        anyhow::bail!(LanguageModelKnownError::Overloaded);
                                    }
                                    LanguageModelCompletionError::ApiInternalServerError =>{
                                        anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
                                    }
                                    LanguageModelCompletionError::PromptTooLarge { tokens } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread.total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(model.max_token_count().saturating_add(1))
                                        });

                                        anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
                                    }
                                    LanguageModelCompletionError::ApiReadResponseError(io_error) => {
                                        anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
                                    }
                                    LanguageModelCompletionError::UnknownResponseFormat(error) => {
                                        anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
                                    }
                                    LanguageModelCompletionError::HttpResponseError { status, ref body } => {
                                        if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
                                            anyhow::bail!(known_error);
                                        } else {
                                            return Err(error.into());
                                        }
                                    }
                                    LanguageModelCompletionError::DeserializeResponse(error) => {
                                        anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
                                    }
                                    LanguageModelCompletionError::BadInputJson {
                                        id,
                                        tool_name,
                                        raw_input: invalid_input_json,
                                        json_parse_error,
                                    } => {
                                        thread.receive_invalid_tool_json(
                                            id,
                                            tool_name,
                                            invalid_input_json,
                                            json_parse_error,
                                            window,
                                            cx,
                                        );
                                        return Ok(());
                                    }
                                    // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
                                    err @ LanguageModelCompletionError::BadRequestFormat |
                                    err @ LanguageModelCompletionError::AuthenticationError |
                                    err @ LanguageModelCompletionError::PermissionError |
                                    err @ LanguageModelCompletionError::ApiEndpointNotFound |
                                    err @ LanguageModelCompletionError::SerializeRequest(_) |
                                    err @ LanguageModelCompletionError::BuildRequestBody(_) |
                                    err @ LanguageModelCompletionError::HttpSend(_) => {
                                        anyhow::bail!(err);
                                    }
                                    LanguageModelCompletionError::Other(error) => {
                                        return Err(error);
                                    }
                                }
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so add
                                // only the delta since the previous update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking {
                                data
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses must hang off an assistant message;
                                // create one if the stream hasn't yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            thread.update_model_request_usage(amount as u32, limit, cx);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other foreground work can interleave between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;
            let mut retry_scheduled = false;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => {
                            match stop_reason {
                                StopReason::ToolUse => {
                                    let tool_uses = thread.use_pending_tools(window, model.clone(), cx);
                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                                }
                                StopReason::EndTurn | StopReason::MaxTokens => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });
                                }
                                StopReason::Refusal => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });

                                    // Remove the turn that was refused.
                                    //
                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                    {
                                        let mut messages_to_remove = Vec::new();

                                        for (ix, message) in thread.messages.iter().enumerate().rev() {
                                            messages_to_remove.push(message.id);

                                            if message.role == Role::User {
                                                if ix == 0 {
                                                    break;
                                                }

                                                if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                    if prev_message.role == Role::Assistant {
                                                        break;
                                                    }
                                                }
                                            }
                                        }

                                        for message_id in messages_to_remove {
                                            thread.delete_message(message_id, cx);
                                        }
                                    }

                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                        header: "Language model refusal".into(),
                                        message: "Model refused to generate content for safety reasons.".into(),
                                    }));
                                }
                            }

                            // We successfully completed, so cancel any remaining retries.
                            thread.retry_state = None;
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Shared fallback: surface the full error chain to the user.
                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                    LanguageModelKnownError::RateLimitExceeded { retry_after } => {
                                        let provider_name = model.provider_name();
                                        let error_message = format!(
                                            "{}'s API rate limit exceeded",
                                            provider_name.0.as_ref()
                                        );

                                        thread.handle_rate_limit_error(
                                            &error_message,
                                            *retry_after,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        retry_scheduled = true;
                                    }
                                    LanguageModelKnownError::Overloaded => {
                                        let provider_name = model.provider_name();
                                        let error_message = format!(
                                            "{}'s API servers are overloaded right now",
                                            provider_name.0.as_ref()
                                        );

                                        retry_scheduled = thread.handle_retryable_error(
                                            &error_message,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        if !retry_scheduled {
                                            emit_generic_error(error, cx);
                                        }
                                    }
                                    LanguageModelKnownError::ApiInternalServerError => {
                                        let provider_name = model.provider_name();
                                        let error_message = format!(
                                            "{}'s API server reported an internal server error",
                                            provider_name.0.as_ref()
                                        );

                                        retry_scheduled = thread.handle_retryable_error(
                                            &error_message,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        if !retry_scheduled {
                                            emit_generic_error(error, cx);
                                        }
                                    }
                                    LanguageModelKnownError::ReadResponseError(_) |
                                    LanguageModelKnownError::DeserializeResponse(_) |
                                    LanguageModelKnownError::UnknownResponseFormat(_) => {
                                        // In the future we will attempt to re-roll response, but only once
                                        emit_generic_error(error, cx);
                                    }
                                }
                            } else {
                                emit_generic_error(error, cx);
                            }

                            if !retry_scheduled {
                                thread.cancel_last_completion(window, cx);
                            }
                        }
                    }

                    if !retry_scheduled {
                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
                    }

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
2006
2007 pub fn summarize(&mut self, cx: &mut Context<Self>) {
2008 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2009 println!("No thread summary model");
2010 return;
2011 };
2012
2013 if !model.provider.is_authenticated(cx) {
2014 return;
2015 }
2016
2017 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2018
2019 let request = self.to_summarize_request(
2020 &model.model,
2021 CompletionIntent::ThreadSummarization,
2022 added_user_message.into(),
2023 cx,
2024 );
2025
2026 self.summary = ThreadSummary::Generating;
2027
2028 self.pending_summary = cx.spawn(async move |this, cx| {
2029 let result = async {
2030 let mut messages = model.model.stream_completion(request, &cx).await?;
2031
2032 let mut new_summary = String::new();
2033 while let Some(event) = messages.next().await {
2034 let Ok(event) = event else {
2035 continue;
2036 };
2037 let text = match event {
2038 LanguageModelCompletionEvent::Text(text) => text,
2039 LanguageModelCompletionEvent::StatusUpdate(
2040 CompletionRequestStatus::UsageUpdated { amount, limit },
2041 ) => {
2042 this.update(cx, |thread, cx| {
2043 thread.update_model_request_usage(amount as u32, limit, cx);
2044 })?;
2045 continue;
2046 }
2047 _ => continue,
2048 };
2049
2050 let mut lines = text.lines();
2051 new_summary.extend(lines.next());
2052
2053 // Stop if the LLM generated multiple lines.
2054 if lines.next().is_some() {
2055 break;
2056 }
2057 }
2058
2059 anyhow::Ok(new_summary)
2060 }
2061 .await;
2062
2063 this.update(cx, |this, cx| {
2064 match result {
2065 Ok(new_summary) => {
2066 if new_summary.is_empty() {
2067 this.summary = ThreadSummary::Error;
2068 } else {
2069 this.summary = ThreadSummary::Ready(new_summary.into());
2070 }
2071 }
2072 Err(err) => {
2073 this.summary = ThreadSummary::Error;
2074 log::error!("Failed to generate thread summary: {}", err);
2075 }
2076 }
2077 cx.emit(ThreadEvent::SummaryGenerated);
2078 })
2079 .log_err()?;
2080
2081 Some(())
2082 });
2083 }
2084
2085 fn handle_rate_limit_error(
2086 &mut self,
2087 error_message: &str,
2088 retry_after: Duration,
2089 model: Arc<dyn LanguageModel>,
2090 intent: CompletionIntent,
2091 window: Option<AnyWindowHandle>,
2092 cx: &mut Context<Self>,
2093 ) {
2094 // For rate limit errors, we only retry once with the specified duration
2095 let retry_message = format!(
2096 "{error_message}. Retrying in {} seconds…",
2097 retry_after.as_secs()
2098 );
2099
2100 // Add a UI-only message instead of a regular message
2101 let id = self.next_message_id.post_inc();
2102 self.messages.push(Message {
2103 id,
2104 role: Role::System,
2105 segments: vec![MessageSegment::Text(retry_message)],
2106 loaded_context: LoadedContext::default(),
2107 creases: Vec::new(),
2108 is_hidden: false,
2109 ui_only: true,
2110 });
2111 cx.emit(ThreadEvent::MessageAdded(id));
2112 // Schedule the retry
2113 let thread_handle = cx.entity().downgrade();
2114
2115 cx.spawn(async move |_thread, cx| {
2116 cx.background_executor().timer(retry_after).await;
2117
2118 thread_handle
2119 .update(cx, |thread, cx| {
2120 // Retry the completion
2121 thread.send_to_model(model, intent, window, cx);
2122 })
2123 .log_err();
2124 })
2125 .detach();
2126 }
2127
    /// Handles a retryable model error using the default exponential-backoff
    /// delay (no custom delay).
    ///
    /// Returns `true` when a retry was scheduled.
    fn handle_retryable_error(
        &mut self,
        error_message: &str,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx)
    }
2138
    /// Handles a retryable model error, scheduling a re-send after either
    /// `custom_delay` (when provided, e.g. from a rate-limit header) or an
    /// exponential backoff derived from the current attempt count.
    ///
    /// Returns `true` when a retry was scheduled. Once `max_attempts` is
    /// exhausted, clears `retry_state` and pending completions, emits
    /// [`ThreadEvent::RetriesFailed`], and returns `false`.
    fn handle_retryable_error_with_delay(
        &mut self,
        error_message: &str,
        custom_delay: Option<Duration>,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Retry state persists across attempts; it is reset on success
        // elsewhere, so `get_or_insert` continues an in-progress sequence.
        let retry_state = self.retry_state.get_or_insert(RetryState {
            attempt: 0,
            max_attempts: MAX_RETRY_ATTEMPTS,
            intent,
        });

        retry_state.attempt += 1;
        let attempt = retry_state.attempt;
        let max_attempts = retry_state.max_attempts;
        let intent = retry_state.intent;

        if attempt <= max_attempts {
            // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
            let delay = if let Some(custom_delay) = custom_delay {
                custom_delay
            } else {
                let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
                Duration::from_secs(delay_secs)
            };

            // Add a transient message to inform the user
            let delay_secs = delay.as_secs();
            let retry_message = format!(
                "{}. Retrying (attempt {} of {}) in {} seconds...",
                error_message, attempt, max_attempts, delay_secs
            );

            // Add a UI-only message instead of a regular message
            let id = self.next_message_id.post_inc();
            self.messages.push(Message {
                id,
                role: Role::System,
                segments: vec![MessageSegment::Text(retry_message)],
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
                is_hidden: false,
                ui_only: true,
            });
            cx.emit(ThreadEvent::MessageAdded(id));

            // Schedule the retry
            let thread_handle = cx.entity().downgrade();

            cx.spawn(async move |_thread, cx| {
                cx.background_executor().timer(delay).await;

                thread_handle
                    .update(cx, |thread, cx| {
                        // Retry the completion
                        thread.send_to_model(model, intent, window, cx);
                    })
                    .log_err();
            })
            .detach();

            true
        } else {
            // Max retries exceeded
            self.retry_state = None;

            let notification_text = if max_attempts == 1 {
                "Failed after retrying.".into()
            } else {
                format!("Failed after retrying {} times.", max_attempts).into()
            };

            // Stop generating since we're giving up on retrying.
            self.pending_completions.clear();

            cx.emit(ThreadEvent::RetriesFailed {
                message: notification_text,
            });

            false
        }
    }
2224
    /// Starts generating a detailed summary of the thread, unless one is
    /// already generated (or currently generating) for the present last
    /// message.
    ///
    /// Requires an authenticated thread-summary model; returns without doing
    /// anything if none is configured. On success the thread is saved through
    /// `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        // Prompt text is compiled into the binary from the prompts directory.
        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Mark as generating before spawning so observers see the transition.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed to start; roll the state back.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate streamed chunks; individual chunk errors are logged
            // and skipped rather than aborting the whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2314
    /// Waits until the detailed summary is no longer in the `Generating`
    /// state, then returns it.
    ///
    /// Returns the generated summary text, the thread's plain text when no
    /// summary was generated, or `None` if the channel closes or the thread
    /// entity has been dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating: keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
2332
2333 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2334 self.detailed_summary_rx
2335 .borrow()
2336 .text()
2337 .unwrap_or_else(|| self.text().into())
2338 }
2339
2340 pub fn is_generating_detailed_summary(&self) -> bool {
2341 matches!(
2342 &*self.detailed_summary_rx.borrow(),
2343 DetailedSummaryState::Generating { .. }
2344 )
2345 }
2346
2347 pub fn use_pending_tools(
2348 &mut self,
2349 window: Option<AnyWindowHandle>,
2350 model: Arc<dyn LanguageModel>,
2351 cx: &mut Context<Self>,
2352 ) -> Vec<PendingToolUse> {
2353 self.auto_capture_telemetry(cx);
2354 let request =
2355 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2356 let pending_tool_uses = self
2357 .tool_use
2358 .pending_tool_uses()
2359 .into_iter()
2360 .filter(|tool_use| tool_use.status.is_idle())
2361 .cloned()
2362 .collect::<Vec<_>>();
2363
2364 for tool_use in pending_tool_uses.iter() {
2365 self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2366 }
2367
2368 pending_tool_uses
2369 }
2370
2371 fn use_pending_tool(
2372 &mut self,
2373 tool_use: PendingToolUse,
2374 request: Arc<LanguageModelRequest>,
2375 model: Arc<dyn LanguageModel>,
2376 window: Option<AnyWindowHandle>,
2377 cx: &mut Context<Self>,
2378 ) {
2379 let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2380 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2381 };
2382
2383 if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2384 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2385 }
2386
2387 if tool.needs_confirmation(&tool_use.input, cx)
2388 && !AgentSettings::get_global(cx).always_allow_tool_actions
2389 {
2390 self.tool_use.confirm_tool_use(
2391 tool_use.id,
2392 tool_use.ui_text,
2393 tool_use.input,
2394 request,
2395 tool,
2396 );
2397 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2398 } else {
2399 self.run_tool(
2400 tool_use.id,
2401 tool_use.ui_text,
2402 tool_use.input,
2403 request,
2404 tool,
2405 model,
2406 window,
2407 cx,
2408 );
2409 }
2410 }
2411
2412 pub fn handle_hallucinated_tool_use(
2413 &mut self,
2414 tool_use_id: LanguageModelToolUseId,
2415 hallucinated_tool_name: Arc<str>,
2416 window: Option<AnyWindowHandle>,
2417 cx: &mut Context<Thread>,
2418 ) {
2419 let available_tools = self.profile.enabled_tools(cx);
2420
2421 let tool_list = available_tools
2422 .iter()
2423 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2424 .collect::<Vec<_>>()
2425 .join("\n");
2426
2427 let error_message = format!(
2428 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2429 hallucinated_tool_name, tool_list
2430 );
2431
2432 let pending_tool_use = self.tool_use.insert_tool_output(
2433 tool_use_id.clone(),
2434 hallucinated_tool_name,
2435 Err(anyhow!("Missing tool call: {error_message}")),
2436 self.configured_model.as_ref(),
2437 );
2438
2439 cx.emit(ThreadEvent::MissingToolUse {
2440 tool_use_id: tool_use_id.clone(),
2441 ui_text: error_message.into(),
2442 });
2443
2444 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2445 }
2446
2447 pub fn receive_invalid_tool_json(
2448 &mut self,
2449 tool_use_id: LanguageModelToolUseId,
2450 tool_name: Arc<str>,
2451 invalid_json: Arc<str>,
2452 error: String,
2453 window: Option<AnyWindowHandle>,
2454 cx: &mut Context<Thread>,
2455 ) {
2456 log::error!("The model returned invalid input JSON: {invalid_json}");
2457
2458 let pending_tool_use = self.tool_use.insert_tool_output(
2459 tool_use_id.clone(),
2460 tool_name,
2461 Err(anyhow!("Error parsing input JSON: {error}")),
2462 self.configured_model.as_ref(),
2463 );
2464 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2465 pending_tool_use.ui_text.clone()
2466 } else {
2467 log::error!(
2468 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2469 );
2470 format!("Unknown tool {}", tool_use_id).into()
2471 };
2472
2473 cx.emit(ThreadEvent::InvalidToolInput {
2474 tool_use_id: tool_use_id.clone(),
2475 ui_text,
2476 invalid_input_json: invalid_json,
2477 });
2478
2479 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2480 }
2481
2482 pub fn run_tool(
2483 &mut self,
2484 tool_use_id: LanguageModelToolUseId,
2485 ui_text: impl Into<SharedString>,
2486 input: serde_json::Value,
2487 request: Arc<LanguageModelRequest>,
2488 tool: Arc<dyn Tool>,
2489 model: Arc<dyn LanguageModel>,
2490 window: Option<AnyWindowHandle>,
2491 cx: &mut Context<Thread>,
2492 ) {
2493 let task =
2494 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2495 self.tool_use
2496 .run_pending_tool(tool_use_id, ui_text.into(), task);
2497 }
2498
2499 fn spawn_tool_use(
2500 &mut self,
2501 tool_use_id: LanguageModelToolUseId,
2502 request: Arc<LanguageModelRequest>,
2503 input: serde_json::Value,
2504 tool: Arc<dyn Tool>,
2505 model: Arc<dyn LanguageModel>,
2506 window: Option<AnyWindowHandle>,
2507 cx: &mut Context<Thread>,
2508 ) -> Task<()> {
2509 let tool_name: Arc<str> = tool.name().into();
2510
2511 let tool_result = tool.run(
2512 input,
2513 request,
2514 self.project.clone(),
2515 self.action_log.clone(),
2516 model,
2517 window,
2518 cx,
2519 );
2520
2521 // Store the card separately if it exists
2522 if let Some(card) = tool_result.card.clone() {
2523 self.tool_use
2524 .insert_tool_result_card(tool_use_id.clone(), card);
2525 }
2526
2527 cx.spawn({
2528 async move |thread: WeakEntity<Thread>, cx| {
2529 let output = tool_result.output.await;
2530
2531 thread
2532 .update(cx, |thread, cx| {
2533 let pending_tool_use = thread.tool_use.insert_tool_output(
2534 tool_use_id.clone(),
2535 tool_name,
2536 output,
2537 thread.configured_model.as_ref(),
2538 );
2539 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2540 })
2541 .ok();
2542 }
2543 })
2544 }
2545
2546 fn tool_finished(
2547 &mut self,
2548 tool_use_id: LanguageModelToolUseId,
2549 pending_tool_use: Option<PendingToolUse>,
2550 canceled: bool,
2551 window: Option<AnyWindowHandle>,
2552 cx: &mut Context<Self>,
2553 ) {
2554 if self.all_tools_finished() {
2555 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2556 if !canceled {
2557 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2558 }
2559 self.auto_capture_telemetry(cx);
2560 }
2561 }
2562
2563 cx.emit(ThreadEvent::ToolFinished {
2564 tool_use_id,
2565 pending_tool_use,
2566 });
2567 }
2568
2569 /// Cancels the last pending completion, if there are any pending.
2570 ///
2571 /// Returns whether a completion was canceled.
2572 pub fn cancel_last_completion(
2573 &mut self,
2574 window: Option<AnyWindowHandle>,
2575 cx: &mut Context<Self>,
2576 ) -> bool {
2577 let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2578
2579 self.retry_state = None;
2580
2581 for pending_tool_use in self.tool_use.cancel_pending() {
2582 canceled = true;
2583 self.tool_finished(
2584 pending_tool_use.id.clone(),
2585 Some(pending_tool_use),
2586 true,
2587 window,
2588 cx,
2589 );
2590 }
2591
2592 if canceled {
2593 cx.emit(ThreadEvent::CompletionCanceled);
2594
2595 // When canceled, we always want to insert the checkpoint.
2596 // (We skip over finalize_pending_checkpoint, because it
2597 // would conclude we didn't have anything to insert here.)
2598 if let Some(checkpoint) = self.pending_checkpoint.take() {
2599 self.insert_checkpoint(checkpoint, cx);
2600 }
2601 } else {
2602 self.finalize_pending_checkpoint(cx);
2603 }
2604
2605 canceled
2606 }
2607
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // The thread holds no editing state itself; it only broadcasts the event.
        cx.emit(ThreadEvent::CancelEditing);
    }
2615
    /// Returns the thread-level feedback (set via `report_feedback`), if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2619
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2623
    /// Records feedback for a single message and reports it to telemetry.
    ///
    /// No-ops when the same feedback is already recorded for the message. The
    /// serialized thread and a project snapshot are captured asynchronously
    /// and attached to the telemetry event.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Skip duplicate feedback so we don't double-report to telemetry.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // aborting the report.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2680
    /// Records thread-level feedback and reports it to telemetry.
    ///
    /// When the thread contains an assistant message, the feedback is attached
    /// to the most recent one via `report_message_feedback`; otherwise a
    /// thread-level telemetry event is sent.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload rather than
                // aborting the report.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2726
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Snapshot every visible worktree concurrently.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect the paths of all dirty (unsaved) buffers. Failure to
            // access app state is ignored; the snapshot is best-effort.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2764
    /// Captures a snapshot of a single worktree: its absolute path plus git
    /// state (remote URL, HEAD SHA, current branch, HEAD-to-worktree diff)
    /// when the worktree belongs to a local repository.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            // NOTE(review): the snapshot is currently unused below — only the
            // path is consumed. Confirm whether it is still needed.
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                // The app context is gone; return an empty snapshot.
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository that contains this worktree, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend to query.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<job>>; any failure yields None.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2842
2843 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2844 let mut markdown = Vec::new();
2845
2846 let summary = self.summary().or_default();
2847 writeln!(markdown, "# {summary}\n")?;
2848
2849 for message in self.messages() {
2850 writeln!(
2851 markdown,
2852 "## {role}\n",
2853 role = match message.role {
2854 Role::User => "User",
2855 Role::Assistant => "Agent",
2856 Role::System => "System",
2857 }
2858 )?;
2859
2860 if !message.loaded_context.text.is_empty() {
2861 writeln!(markdown, "{}", message.loaded_context.text)?;
2862 }
2863
2864 if !message.loaded_context.images.is_empty() {
2865 writeln!(
2866 markdown,
2867 "\n{} images attached as context.\n",
2868 message.loaded_context.images.len()
2869 )?;
2870 }
2871
2872 for segment in &message.segments {
2873 match segment {
2874 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2875 MessageSegment::Thinking { text, .. } => {
2876 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2877 }
2878 MessageSegment::RedactedThinking(_) => {}
2879 }
2880 }
2881
2882 for tool_use in self.tool_uses_for_message(message.id, cx) {
2883 writeln!(
2884 markdown,
2885 "**Use Tool: {} ({})**",
2886 tool_use.name, tool_use.id
2887 )?;
2888 writeln!(markdown, "```json")?;
2889 writeln!(
2890 markdown,
2891 "{}",
2892 serde_json::to_string_pretty(&tool_use.input)?
2893 )?;
2894 writeln!(markdown, "```")?;
2895 }
2896
2897 for tool_result in self.tool_results_for_message(message.id) {
2898 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2899 if tool_result.is_error {
2900 write!(markdown, " (Error)")?;
2901 }
2902
2903 writeln!(markdown, "**\n")?;
2904 match &tool_result.content {
2905 LanguageModelToolResultContent::Text(text) => {
2906 writeln!(markdown, "{text}")?;
2907 }
2908 LanguageModelToolResultContent::Image(image) => {
2909 writeln!(markdown, "", image.source)?;
2910 }
2911 }
2912
2913 if let Some(output) = tool_result.output.as_ref() {
2914 writeln!(
2915 markdown,
2916 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2917 serde_json::to_string_pretty(output)?
2918 )?;
2919 }
2920 }
2921 }
2922
2923 Ok(String::from_utf8_lossy(&markdown).to_string())
2924 }
2925
    /// Marks agent edits within `buffer_range` of `buffer` as kept (accepted).
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2936
    /// Marks every agent edit tracked by the action log as kept (accepted).
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2941
    /// Rejects (reverts) agent edits within the given ranges of `buffer`.
    ///
    /// Returns a task that resolves once the rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2952
    /// The action log tracking agent edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2956
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2960
2961 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2962 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2963 return;
2964 }
2965
2966 let now = Instant::now();
2967 if let Some(last) = self.last_auto_capture_at {
2968 if now.duration_since(last).as_secs() < 10 {
2969 return;
2970 }
2971 }
2972
2973 self.last_auto_capture_at = Some(now);
2974
2975 let thread_id = self.id().clone();
2976 let github_login = self
2977 .project
2978 .read(cx)
2979 .user_store()
2980 .read(cx)
2981 .current_user()
2982 .map(|user| user.github_login.clone());
2983 let client = self.project.read(cx).client();
2984 let serialize_task = self.serialize(cx);
2985
2986 cx.background_executor()
2987 .spawn(async move {
2988 if let Ok(serialized_thread) = serialize_task.await {
2989 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2990 telemetry::event!(
2991 "Agent Thread Auto-Captured",
2992 thread_id = thread_id.to_string(),
2993 thread_data = thread_data,
2994 auto_capture_reason = "tracked_user",
2995 github_login = github_login
2996 );
2997
2998 client.telemetry().flush_events().await;
2999 }
3000 }
3001 })
3002 .detach();
3003 }
3004
    /// Total token usage accumulated across every request in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3008
3009 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3010 let Some(model) = self.configured_model.as_ref() else {
3011 return TotalTokenUsage::default();
3012 };
3013
3014 let max = model.model.max_token_count();
3015
3016 let index = self
3017 .messages
3018 .iter()
3019 .position(|msg| msg.id == message_id)
3020 .unwrap_or(0);
3021
3022 if index == 0 {
3023 return TotalTokenUsage { total: 0, max };
3024 }
3025
3026 let token_usage = &self
3027 .request_token_usage
3028 .get(index - 1)
3029 .cloned()
3030 .unwrap_or_default();
3031
3032 TotalTokenUsage {
3033 total: token_usage.total_tokens(),
3034 max,
3035 }
3036 }
3037
3038 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3039 let model = self.configured_model.as_ref()?;
3040
3041 let max = model.model.max_token_count();
3042
3043 if let Some(exceeded_error) = &self.exceeded_window_error {
3044 if model.model.id() == exceeded_error.model_id {
3045 return Some(TotalTokenUsage {
3046 total: exceeded_error.token_count,
3047 max,
3048 });
3049 }
3050 }
3051
3052 let total = self
3053 .token_usage_at_last_message()
3054 .unwrap_or_default()
3055 .total_tokens();
3056
3057 Some(TotalTokenUsage { total, max })
3058 }
3059
3060 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
3061 self.request_token_usage
3062 .get(self.messages.len().saturating_sub(1))
3063 .or_else(|| self.request_token_usage.last())
3064 .cloned()
3065 }
3066
3067 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3068 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3069 self.request_token_usage
3070 .resize(self.messages.len(), placeholder);
3071
3072 if let Some(last) = self.request_token_usage.last_mut() {
3073 *last = token_usage;
3074 }
3075 }
3076
    /// Propagates updated model-request usage (amount used plus limit) to the
    /// user store.
    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            project.user_store().update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        // NOTE(review): u32 -> i32 cast; assumes usage fits in
                        // i32 (the RequestUsage field type) — confirm upstream.
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            })
        });
    }
3090
3091 pub fn deny_tool_use(
3092 &mut self,
3093 tool_use_id: LanguageModelToolUseId,
3094 tool_name: Arc<str>,
3095 window: Option<AnyWindowHandle>,
3096 cx: &mut Context<Self>,
3097 ) {
3098 let err = Err(anyhow::anyhow!(
3099 "Permission to run tool action denied by user"
3100 ));
3101
3102 self.tool_use.insert_tool_output(
3103 tool_use_id.clone(),
3104 tool_name,
3105 err,
3106 self.configured_model.as_ref(),
3107 );
3108 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3109 }
3110}
3111
/// User-facing errors surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's plan requires payment before further requests.
    #[error("Payment required")]
    PaymentRequired,
    /// The plan's model request limit has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a display header and message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3124
/// Events emitted by a `Thread` for subscribers (e.g. the agent panel UI).
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    /// A completion event was streamed from the model.
    StreamedCompletion,
    /// A chunk of text was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text was streamed for the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text was streamed for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model requested a tool that doesn't exist or isn't enabled.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model sent tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// A message's content was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint was inserted or changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// The tool-use limit was reached.
    ToolUseLimitReached,
    /// Any in-progress editing should be canceled.
    CancelEditing,
    /// A completion (or scheduled retry / running tools) was canceled.
    CompletionCanceled,
    /// The agent profile changed.
    ProfileChanged,
    /// All retry attempts failed; generation has stopped.
    RetriesFailed {
        message: SharedString,
    },
}
3172
// Threads broadcast `ThreadEvent`s to subscribers via gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
3174
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier for this completion within the thread.
    id: usize,
    // Where the completion currently stands in the request queue.
    queue_state: QueueState,
    // Task driving the completion; presumably canceled when this struct is
    // dropped (gpui task semantics) — held only to keep it alive.
    _task: Task<()>,
}
3180
3181/// Resolves tool name conflicts by ensuring all tool names are unique.
3182///
3183/// When multiple tools have the same name, this function applies the following rules:
3184/// 1. Native tools always keep their original name
3185/// 2. Context server tools get prefixed with their server ID and an underscore
3186/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
3187/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
3188///
3189/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
3190fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
3191 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
3192 let mut tool_name = tool.name();
3193 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
3194 tool_name
3195 }
3196
3197 const MAX_TOOL_NAME_LENGTH: usize = 64;
3198
3199 let mut duplicated_tool_names = HashSet::default();
3200 let mut seen_tool_names = HashSet::default();
3201 for tool in tools {
3202 let tool_name = resolve_tool_name(tool);
3203 if seen_tool_names.contains(&tool_name) {
3204 debug_assert!(
3205 tool.source() != assistant_tool::ToolSource::Native,
3206 "There are two built-in tools with the same name: {}",
3207 tool_name
3208 );
3209 duplicated_tool_names.insert(tool_name);
3210 } else {
3211 seen_tool_names.insert(tool_name);
3212 }
3213 }
3214
3215 if duplicated_tool_names.is_empty() {
3216 return tools
3217 .into_iter()
3218 .map(|tool| (resolve_tool_name(tool), tool.clone()))
3219 .collect();
3220 }
3221
3222 tools
3223 .into_iter()
3224 .filter_map(|tool| {
3225 let mut tool_name = resolve_tool_name(tool);
3226 if !duplicated_tool_names.contains(&tool_name) {
3227 return Some((tool_name, tool.clone()));
3228 }
3229 match tool.source() {
3230 assistant_tool::ToolSource::Native => {
3231 // Built-in tools always keep their original name
3232 Some((tool_name, tool.clone()))
3233 }
3234 assistant_tool::ToolSource::ContextServer { id } => {
3235 // Context server tools are prefixed with the context server ID, and truncated if necessary
3236 tool_name.insert(0, '_');
3237 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
3238 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
3239 let mut id = id.to_string();
3240 id.truncate(len);
3241 tool_name.insert_str(0, &id);
3242 } else {
3243 tool_name.insert_str(0, &id);
3244 }
3245
3246 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
3247
3248 if seen_tool_names.contains(&tool_name) {
3249 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
3250 None
3251 } else {
3252 Some((tool_name, tool.clone()))
3253 }
3254 }
3255 }
3256 })
3257 .collect()
3258}
3259
3260#[cfg(test)]
3261mod tests {
3262 use super::*;
3263 use crate::{
3264 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3265 };
3266
3267 // Test-specific constants
3268 const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3269 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3270 use assistant_tool::ToolRegistry;
3271 use futures::StreamExt;
3272 use futures::future::BoxFuture;
3273 use futures::stream::BoxStream;
3274 use gpui::TestAppContext;
3275 use icons::IconName;
3276 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3277 use language_model::{
3278 LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3279 LanguageModelProviderName, LanguageModelToolChoice,
3280 };
3281 use parking_lot::Mutex;
3282 use project::{FakeFs, Project};
3283 use prompt_store::PromptBuilder;
3284 use serde_json::json;
3285 use settings::{Settings, SettingsStore};
3286 use std::sync::Arc;
3287 use std::time::Duration;
3288 use theme::ThemeSettings;
3289 use util::path;
3290 use workspace::Workspace;
3291
    /// Verifies that a file attached as context is embedded in the user
    /// message and included verbatim in the completion request sent to the
    /// model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // Expected <context> block wrapping the attached file's contents.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Index 0 is the system prompt; index 1 is the user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3368
    /// Each user message should only embed context items that were not
    /// already attached to an earlier message in the thread, and
    /// `new_context_for_thread` should honor an optional "as of message"
    /// cutoff as well as removals from the context store.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages (plus one leading message).
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message:
        // each request message carries exactly the context that was new
        // when it was inserted.
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, only contexts attached strictly before
        // message 2 (file1) are excluded; file2, file3, and the newly added
        // file4 all count as "new".
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3524
3525 #[gpui::test]
3526 async fn test_message_without_files(cx: &mut TestAppContext) {
3527 init_test_settings(cx);
3528
3529 let project = create_test_project(
3530 cx,
3531 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3532 )
3533 .await;
3534
3535 let (_, _thread_store, thread, _context_store, model) =
3536 setup_test_environment(cx, project.clone()).await;
3537
3538 // Insert user message without any context (empty context vector)
3539 let message_id = thread.update(cx, |thread, cx| {
3540 thread.insert_user_message(
3541 "What is the best way to learn Rust?",
3542 ContextLoadResult::default(),
3543 None,
3544 Vec::new(),
3545 cx,
3546 )
3547 });
3548
3549 // Check content and context in message object
3550 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3551
3552 // Context should be empty when no files are included
3553 assert_eq!(message.role, Role::User);
3554 assert_eq!(message.segments.len(), 1);
3555 assert_eq!(
3556 message.segments[0],
3557 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3558 );
3559 assert_eq!(message.loaded_context.text, "");
3560
3561 // Check message in request
3562 let request = thread.update(cx, |thread, cx| {
3563 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3564 });
3565
3566 assert_eq!(request.messages.len(), 2);
3567 assert_eq!(
3568 request.messages[1].string_contents(),
3569 "What is the best way to learn Rust?"
3570 );
3571
3572 // Add second message, also without context
3573 let message2_id = thread.update(cx, |thread, cx| {
3574 thread.insert_user_message(
3575 "Are there any good books?",
3576 ContextLoadResult::default(),
3577 None,
3578 Vec::new(),
3579 cx,
3580 )
3581 });
3582
3583 let message2 =
3584 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3585 assert_eq!(message2.loaded_context.text, "");
3586
3587 // Check that both messages appear in the request
3588 let request = thread.update(cx, |thread, cx| {
3589 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3590 });
3591
3592 assert_eq!(request.messages.len(), 3);
3593 assert_eq!(
3594 request.messages[1].string_contents(),
3595 "What is the best way to learn Rust?"
3596 );
3597 assert_eq!(
3598 request.messages[2].string_contents(),
3599 "Are there any good books?"
3600 );
3601 }
3602
3603 #[gpui::test]
3604 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3605 init_test_settings(cx);
3606
3607 let project = create_test_project(
3608 cx,
3609 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3610 )
3611 .await;
3612
3613 let (_workspace, thread_store, thread, _context_store, _model) =
3614 setup_test_environment(cx, project.clone()).await;
3615
3616 // Check that we are starting with the default profile
3617 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3618 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3619 assert_eq!(
3620 profile,
3621 AgentProfile::new(AgentProfileId::default(), tool_set)
3622 );
3623 }
3624
    /// The agent profile round-trips through serialization: serializing a
    /// fresh thread records the default profile id, and deserializing that
    /// form restores an equivalent `AgentProfile`.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form, reusing the live
        // thread's id, project, tools, prompt builder, and project context.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        // The restored thread carries the same default profile.
        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3667
    /// `model_parameters` entries in the agent settings override the request
    /// temperature when their provider/model filters match the active model.
    /// A `None` filter matches anything; a provider mismatch means the entry
    /// does not apply and the temperature stays unset.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model (provider filter left unset still matches)
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider (model filter left unset still matches)
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider: the entry must NOT apply,
        // so the request temperature stays unset.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3761
    /// Walks the summary state machine: Pending -> Generating -> Ready.
    /// Manual `set_summary` calls are rejected while Pending or Generating,
    /// accepted once Ready, and an empty manual summary falls back to the
    /// default title.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream so
        // the generation task can finish.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set (chunks concatenated)
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3846
3847 #[gpui::test]
3848 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3849 init_test_settings(cx);
3850
3851 let project = create_test_project(cx, json!({})).await;
3852
3853 let (_, _thread_store, thread, _context_store, model) =
3854 setup_test_environment(cx, project.clone()).await;
3855
3856 test_summarize_error(&model, &thread, cx);
3857
3858 // Now we should be able to set a summary
3859 thread.update(cx, |thread, cx| {
3860 thread.set_summary("Brief Intro", cx);
3861 });
3862
3863 thread.read_with(cx, |thread, _| {
3864 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3865 assert_eq!(thread.summary().or_default(), "Brief Intro");
3866 });
3867 }
3868
    /// After a summarization failure the thread stays in the Error state:
    /// sending further messages does not retrigger summarization, but an
    /// explicit `summarize` call retries and can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the summary-error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary and close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3919
    /// `resolve_tool_name_conflicts` keeps unique names as-is, prefixes
    /// colliding context-server tool names with their server id, and
    /// truncates any resulting name to 64 characters (truncating the server
    /// prefix first when a prefixed name would be too long).
    #[gpui::test]
    fn test_resolve_tool_name_conflicts() {
        use assistant_tool::{Tool, ToolSource};

        // No collisions: names pass through untouched.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
            ],
            vec!["tool1", "tool2", "tool3"],
        );

        // Two context servers expose the same tool name: both get prefixed
        // with their server id.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // A native tool keeps its name; only the context-server duplicates
        // are prefixed.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // Test that tool with very long name is always truncated
        assert_resolve_tool_name_conflicts(
            vec![TestTool::new(
                "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
                ToolSource::Native,
            )],
            vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
        );

        // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
                TestTool::new(
                    "tool-with-very-very-very-long-name",
                    ToolSource::ContextServer {
                        id: "mcp-with-very-very-very-long-name".into(),
                    },
                ),
            ],
            vec![
                "tool-with-very-very-very-long-name",
                "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
            ],
        );

        // Runs the resolver over `tools` and asserts the resulting names
        // match `expected`, in order.
        fn assert_resolve_tool_name_conflicts(
            tools: Vec<TestTool>,
            expected: Vec<impl Into<String>>,
        ) {
            let tools: Vec<Arc<dyn Tool>> = tools
                .into_iter()
                .map(|t| Arc::new(t) as Arc<dyn Tool>)
                .collect();
            let tools = resolve_tool_name_conflicts(&tools);
            assert_eq!(tools.len(), expected.len());
            for (i, expected_name) in expected.into_iter().enumerate() {
                let expected_name = expected_name.into();
                // Each resolved entry is a tuple whose first element is the
                // (possibly prefixed/truncated) tool name.
                let actual_name = &tools[i].0;
                assert_eq!(
                    actual_name, &expected_name,
                    "Expected '{}' got '{}' at index {}",
                    expected_name, actual_name, i
                );
            }
        }

        // Minimal `Tool` implementation carrying just a name and a source,
        // which is all the resolver looks at.
        struct TestTool {
            name: String,
            source: ToolSource,
        }

        impl TestTool {
            fn new(name: impl Into<String>, source: ToolSource) -> Self {
                Self {
                    name: name.into(),
                    source,
                }
            }
        }

        impl Tool for TestTool {
            fn name(&self) -> String {
                self.name.clone()
            }

            fn icon(&self) -> IconName {
                IconName::Ai
            }

            fn may_perform_edits(&self) -> bool {
                false
            }

            fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
                true
            }

            fn source(&self) -> ToolSource {
                self.source.clone()
            }

            fn description(&self) -> String {
                "Test tool".to_string()
            }

            fn ui_text(&self, _input: &serde_json::Value) -> String {
                "Test tool".to_string()
            }

            // Never meant to run; always resolves to an error.
            fn run(
                self: Arc<Self>,
                _input: serde_json::Value,
                _request: Arc<LanguageModelRequest>,
                _project: Entity<Project>,
                _action_log: Entity<ActionLog>,
                _model: Arc<dyn LanguageModel>,
                _window: Option<AnyWindowHandle>,
                _cx: &mut App,
            ) -> assistant_tool::ToolResult {
                assistant_tool::ToolResult {
                    output: Task::ready(Err(anyhow::anyhow!("No content"))),
                    card: None,
                }
            }
        }
    }
4061
    // Helper to create a model that returns errors
    /// Selects which `LanguageModelCompletionError` an `ErrorInjector`
    /// yields from `stream_completion`.
    enum TestError {
        // Mapped to `LanguageModelCompletionError::Overloaded`.
        Overloaded,
        // Mapped to `LanguageModelCompletionError::ApiInternalServerError`.
        InternalServerError,
    }
4067
4068 struct ErrorInjector {
4069 inner: Arc<FakeLanguageModel>,
4070 error_type: TestError,
4071 }
4072
4073 impl ErrorInjector {
4074 fn new(error_type: TestError) -> Self {
4075 Self {
4076 inner: Arc::new(FakeLanguageModel::default()),
4077 error_type,
4078 }
4079 }
4080 }
4081
    // Forwards every metadata query to the wrapped fake model, but fails all
    // completion requests with the configured error so tests can exercise
    // the retry path deterministically.
    impl LanguageModel for ErrorInjector {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        // The request future itself resolves Ok, but the returned stream's
        // single item is the configured error; the request is ignored.
        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Map the configured failure mode onto a completion error.
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::Overloaded,
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError
                }
            };
            async move {
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        fn as_fake(&self) -> &FakeLanguageModel {
            // Deref-coerces the inner Arc into a plain reference.
            &self.inner
        }
    }
4158
    /// An "overloaded" completion error should initialize retry state
    /// (attempt 1 of MAX_RETRY_ATTEMPTS) and add exactly one ui-only system
    /// message describing the scheduled retry.
    #[gpui::test]
    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // The failed completion should have armed the retry state.
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have default max attempts"
            );
        });

        // Check that a retry message was added: a ui-only System message
        // mentioning the overload and the attempt counter.
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("overloaded")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message"
            );
        });

        // Count ui-only "Retrying … seconds" notices; exactly one retry
        // should have been scheduled so far.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4231
    /// An internal-server completion error should also arm the retry state
    /// and add a ui-only system message that names the provider ("Fake")
    /// along with the attempt counter.
    #[gpui::test]
    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns internal server error
        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Check retry state on thread
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have correct max attempts"
            );
        });

        // Check that a retry message was added with provider name
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("internal")
                                    && text.contains("Fake")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message with provider name"
            );
        });

        // Count ui-only "Retrying … seconds" notices; exactly one retry
        // should have been scheduled.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4307
4308 #[gpui::test]
4309 async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4310 init_test_settings(cx);
4311
4312 let project = create_test_project(cx, json!({})).await;
4313 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4314
4315 // Create model that returns overloaded error
4316 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4317
4318 // Insert a user message
4319 thread.update(cx, |thread, cx| {
4320 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4321 });
4322
4323 // Track retry events and completion count
4324 // Track completion events
4325 let completion_count = Arc::new(Mutex::new(0));
4326 let completion_count_clone = completion_count.clone();
4327
4328 let _subscription = thread.update(cx, |_, cx| {
4329 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4330 if let ThreadEvent::NewRequest = event {
4331 *completion_count_clone.lock() += 1;
4332 }
4333 })
4334 });
4335
4336 // First attempt
4337 thread.update(cx, |thread, cx| {
4338 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4339 });
4340 cx.run_until_parked();
4341
4342 // Should have scheduled first retry - count retry messages
4343 let retry_count = thread.update(cx, |thread, _| {
4344 thread
4345 .messages
4346 .iter()
4347 .filter(|m| {
4348 m.ui_only
4349 && m.segments.iter().any(|s| {
4350 if let MessageSegment::Text(text) = s {
4351 text.contains("Retrying") && text.contains("seconds")
4352 } else {
4353 false
4354 }
4355 })
4356 })
4357 .count()
4358 });
4359 assert_eq!(retry_count, 1, "Should have scheduled first retry");
4360
4361 // Check retry state
4362 thread.read_with(cx, |thread, _| {
4363 assert!(thread.retry_state.is_some(), "Should have retry state");
4364 let retry_state = thread.retry_state.as_ref().unwrap();
4365 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4366 });
4367
4368 // Advance clock for first retry
4369 cx.executor()
4370 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4371 cx.run_until_parked();
4372
4373 // Should have scheduled second retry - count retry messages
4374 let retry_count = thread.update(cx, |thread, _| {
4375 thread
4376 .messages
4377 .iter()
4378 .filter(|m| {
4379 m.ui_only
4380 && m.segments.iter().any(|s| {
4381 if let MessageSegment::Text(text) = s {
4382 text.contains("Retrying") && text.contains("seconds")
4383 } else {
4384 false
4385 }
4386 })
4387 })
4388 .count()
4389 });
4390 assert_eq!(retry_count, 2, "Should have scheduled second retry");
4391
4392 // Check retry state updated
4393 thread.read_with(cx, |thread, _| {
4394 assert!(thread.retry_state.is_some(), "Should have retry state");
4395 let retry_state = thread.retry_state.as_ref().unwrap();
4396 assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
4397 assert_eq!(
4398 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4399 "Should have correct max attempts"
4400 );
4401 });
4402
4403 // Advance clock for second retry (exponential backoff)
4404 cx.executor()
4405 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
4406 cx.run_until_parked();
4407
4408 // Should have scheduled third retry
4409 // Count all retry messages now
4410 let retry_count = thread.update(cx, |thread, _| {
4411 thread
4412 .messages
4413 .iter()
4414 .filter(|m| {
4415 m.ui_only
4416 && m.segments.iter().any(|s| {
4417 if let MessageSegment::Text(text) = s {
4418 text.contains("Retrying") && text.contains("seconds")
4419 } else {
4420 false
4421 }
4422 })
4423 })
4424 .count()
4425 });
4426 assert_eq!(
4427 retry_count, MAX_RETRY_ATTEMPTS as usize,
4428 "Should have scheduled third retry"
4429 );
4430
4431 // Check retry state updated
4432 thread.read_with(cx, |thread, _| {
4433 assert!(thread.retry_state.is_some(), "Should have retry state");
4434 let retry_state = thread.retry_state.as_ref().unwrap();
4435 assert_eq!(
4436 retry_state.attempt, MAX_RETRY_ATTEMPTS,
4437 "Should be at max retry attempt"
4438 );
4439 assert_eq!(
4440 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4441 "Should have correct max attempts"
4442 );
4443 });
4444
4445 // Advance clock for third retry (exponential backoff)
4446 cx.executor()
4447 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
4448 cx.run_until_parked();
4449
4450 // No more retries should be scheduled after clock was advanced.
4451 let retry_count = thread.update(cx, |thread, _| {
4452 thread
4453 .messages
4454 .iter()
4455 .filter(|m| {
4456 m.ui_only
4457 && m.segments.iter().any(|s| {
4458 if let MessageSegment::Text(text) = s {
4459 text.contains("Retrying") && text.contains("seconds")
4460 } else {
4461 false
4462 }
4463 })
4464 })
4465 .count()
4466 });
4467 assert_eq!(
4468 retry_count, MAX_RETRY_ATTEMPTS as usize,
4469 "Should not exceed max retries"
4470 );
4471
4472 // Final completion count should be initial + max retries
4473 assert_eq!(
4474 *completion_count.lock(),
4475 (MAX_RETRY_ATTEMPTS + 1) as usize,
4476 "Should have made initial + max retry attempts"
4477 );
4478 }
4479
    #[gpui::test]
    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error on every attempt, so the
        // thread is forced to burn through all of its scheduled retries.
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track events: flip a flag when the thread reports that every retry
        // attempt was consumed without success.
        let retries_failed = Arc::new(Mutex::new(false));
        let retries_failed_clone = retries_failed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::RetriesFailed { .. } = event {
                    *retries_failed_clone.lock() = true;
                }
            })
        });

        // Start initial completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Advance through all retries. Retry delays back off exponentially:
        // attempt 1 waits BASE_RETRY_DELAY_SECS, attempt n waits
        // BASE_RETRY_DELAY_SECS * 2^(n-1). Advancing the fake clock by each
        // delay in turn lets every scheduled retry fire (and fail).
        for i in 0..MAX_RETRY_ATTEMPTS {
            let delay = if i == 0 {
                BASE_RETRY_DELAY_SECS
            } else {
                BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
            };
            cx.executor().advance_clock(Duration::from_secs(delay));
            cx.run_until_parked();
        }

        // After the 3rd retry is scheduled, we need to wait for it to execute and fail
        // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
        let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
        cx.executor()
            .advance_clock(Duration::from_secs(final_delay));
        cx.run_until_parked();

        // Count the ui-only "Retrying in N seconds" notifications; the thread
        // inserts exactly one per scheduled retry attempt.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        // After max retries, should emit RetriesFailed event
        assert_eq!(
            retry_count, MAX_RETRY_ATTEMPTS as usize,
            "Should have attempted max retries"
        );
        assert!(
            *retries_failed.lock(),
            "Should emit RetriesFailed event after max retries exceeded"
        );

        // Retry state should be cleared once the thread gives up.
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after max retries"
            );

            // Verify we have the expected number of retry messages
            let retry_messages = thread
                .messages
                .iter()
                .filter(|msg| msg.ui_only && msg.role == Role::System)
                .count();
            assert_eq!(
                retry_messages, MAX_RETRY_ATTEMPTS as usize,
                "Should have one retry message per attempt"
            );
        });
    }
4577
4578 #[gpui::test]
4579 async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
4580 init_test_settings(cx);
4581
4582 let project = create_test_project(cx, json!({})).await;
4583 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4584
4585 // We'll use a wrapper to switch behavior after first failure
4586 struct RetryTestModel {
4587 inner: Arc<FakeLanguageModel>,
4588 failed_once: Arc<Mutex<bool>>,
4589 }
4590
4591 impl LanguageModel for RetryTestModel {
4592 fn id(&self) -> LanguageModelId {
4593 self.inner.id()
4594 }
4595
4596 fn name(&self) -> LanguageModelName {
4597 self.inner.name()
4598 }
4599
4600 fn provider_id(&self) -> LanguageModelProviderId {
4601 self.inner.provider_id()
4602 }
4603
4604 fn provider_name(&self) -> LanguageModelProviderName {
4605 self.inner.provider_name()
4606 }
4607
4608 fn supports_tools(&self) -> bool {
4609 self.inner.supports_tools()
4610 }
4611
4612 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4613 self.inner.supports_tool_choice(choice)
4614 }
4615
4616 fn supports_images(&self) -> bool {
4617 self.inner.supports_images()
4618 }
4619
4620 fn telemetry_id(&self) -> String {
4621 self.inner.telemetry_id()
4622 }
4623
4624 fn max_token_count(&self) -> u64 {
4625 self.inner.max_token_count()
4626 }
4627
4628 fn count_tokens(
4629 &self,
4630 request: LanguageModelRequest,
4631 cx: &App,
4632 ) -> BoxFuture<'static, Result<u64>> {
4633 self.inner.count_tokens(request, cx)
4634 }
4635
4636 fn stream_completion(
4637 &self,
4638 request: LanguageModelRequest,
4639 cx: &AsyncApp,
4640 ) -> BoxFuture<
4641 'static,
4642 Result<
4643 BoxStream<
4644 'static,
4645 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4646 >,
4647 LanguageModelCompletionError,
4648 >,
4649 > {
4650 if !*self.failed_once.lock() {
4651 *self.failed_once.lock() = true;
4652 // Return error on first attempt
4653 let stream = futures::stream::once(async move {
4654 Err(LanguageModelCompletionError::Overloaded)
4655 });
4656 async move { Ok(stream.boxed()) }.boxed()
4657 } else {
4658 // Succeed on retry
4659 self.inner.stream_completion(request, cx)
4660 }
4661 }
4662
4663 fn as_fake(&self) -> &FakeLanguageModel {
4664 &self.inner
4665 }
4666 }
4667
4668 let model = Arc::new(RetryTestModel {
4669 inner: Arc::new(FakeLanguageModel::default()),
4670 failed_once: Arc::new(Mutex::new(false)),
4671 });
4672
4673 // Insert a user message
4674 thread.update(cx, |thread, cx| {
4675 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4676 });
4677
4678 // Track message deletions
4679 // Track when retry completes successfully
4680 let retry_completed = Arc::new(Mutex::new(false));
4681 let retry_completed_clone = retry_completed.clone();
4682
4683 let _subscription = thread.update(cx, |_, cx| {
4684 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4685 if let ThreadEvent::StreamedCompletion = event {
4686 *retry_completed_clone.lock() = true;
4687 }
4688 })
4689 });
4690
4691 // Start completion
4692 thread.update(cx, |thread, cx| {
4693 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4694 });
4695 cx.run_until_parked();
4696
4697 // Get the retry message ID
4698 let retry_message_id = thread.read_with(cx, |thread, _| {
4699 thread
4700 .messages()
4701 .find(|msg| msg.role == Role::System && msg.ui_only)
4702 .map(|msg| msg.id)
4703 .expect("Should have a retry message")
4704 });
4705
4706 // Wait for retry
4707 cx.executor()
4708 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4709 cx.run_until_parked();
4710
4711 // Stream some successful content
4712 let fake_model = model.as_fake();
4713 // After the retry, there should be a new pending completion
4714 let pending = fake_model.pending_completions();
4715 assert!(
4716 !pending.is_empty(),
4717 "Should have a pending completion after retry"
4718 );
4719 fake_model.stream_completion_response(&pending[0], "Success!");
4720 fake_model.end_completion_stream(&pending[0]);
4721 cx.run_until_parked();
4722
4723 // Check that the retry completed successfully
4724 assert!(
4725 *retry_completed.lock(),
4726 "Retry should have completed successfully"
4727 );
4728
4729 // Retry message should still exist but be marked as ui_only
4730 thread.read_with(cx, |thread, _| {
4731 let retry_msg = thread
4732 .message(retry_message_id)
4733 .expect("Retry message should still exist");
4734 assert!(retry_msg.ui_only, "Retry message should be ui_only");
4735 assert_eq!(
4736 retry_msg.role,
4737 Role::System,
4738 "Retry message should have System role"
4739 );
4740 });
4741 }
4742
4743 #[gpui::test]
4744 async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4745 init_test_settings(cx);
4746
4747 let project = create_test_project(cx, json!({})).await;
4748 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4749
4750 // Create a model that fails once then succeeds
4751 struct FailOnceModel {
4752 inner: Arc<FakeLanguageModel>,
4753 failed_once: Arc<Mutex<bool>>,
4754 }
4755
4756 impl LanguageModel for FailOnceModel {
4757 fn id(&self) -> LanguageModelId {
4758 self.inner.id()
4759 }
4760
4761 fn name(&self) -> LanguageModelName {
4762 self.inner.name()
4763 }
4764
4765 fn provider_id(&self) -> LanguageModelProviderId {
4766 self.inner.provider_id()
4767 }
4768
4769 fn provider_name(&self) -> LanguageModelProviderName {
4770 self.inner.provider_name()
4771 }
4772
4773 fn supports_tools(&self) -> bool {
4774 self.inner.supports_tools()
4775 }
4776
4777 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4778 self.inner.supports_tool_choice(choice)
4779 }
4780
4781 fn supports_images(&self) -> bool {
4782 self.inner.supports_images()
4783 }
4784
4785 fn telemetry_id(&self) -> String {
4786 self.inner.telemetry_id()
4787 }
4788
4789 fn max_token_count(&self) -> u64 {
4790 self.inner.max_token_count()
4791 }
4792
4793 fn count_tokens(
4794 &self,
4795 request: LanguageModelRequest,
4796 cx: &App,
4797 ) -> BoxFuture<'static, Result<u64>> {
4798 self.inner.count_tokens(request, cx)
4799 }
4800
4801 fn stream_completion(
4802 &self,
4803 request: LanguageModelRequest,
4804 cx: &AsyncApp,
4805 ) -> BoxFuture<
4806 'static,
4807 Result<
4808 BoxStream<
4809 'static,
4810 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4811 >,
4812 LanguageModelCompletionError,
4813 >,
4814 > {
4815 if !*self.failed_once.lock() {
4816 *self.failed_once.lock() = true;
4817 // Return error on first attempt
4818 let stream = futures::stream::once(async move {
4819 Err(LanguageModelCompletionError::Overloaded)
4820 });
4821 async move { Ok(stream.boxed()) }.boxed()
4822 } else {
4823 // Succeed on retry
4824 self.inner.stream_completion(request, cx)
4825 }
4826 }
4827 }
4828
4829 let fail_once_model = Arc::new(FailOnceModel {
4830 inner: Arc::new(FakeLanguageModel::default()),
4831 failed_once: Arc::new(Mutex::new(false)),
4832 });
4833
4834 // Insert a user message
4835 thread.update(cx, |thread, cx| {
4836 thread.insert_user_message(
4837 "Test message",
4838 ContextLoadResult::default(),
4839 None,
4840 vec![],
4841 cx,
4842 );
4843 });
4844
4845 // Start completion with fail-once model
4846 thread.update(cx, |thread, cx| {
4847 thread.send_to_model(
4848 fail_once_model.clone(),
4849 CompletionIntent::UserPrompt,
4850 None,
4851 cx,
4852 );
4853 });
4854
4855 cx.run_until_parked();
4856
4857 // Verify retry state exists after first failure
4858 thread.read_with(cx, |thread, _| {
4859 assert!(
4860 thread.retry_state.is_some(),
4861 "Should have retry state after failure"
4862 );
4863 });
4864
4865 // Wait for retry delay
4866 cx.executor()
4867 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4868 cx.run_until_parked();
4869
4870 // The retry should now use our FailOnceModel which should succeed
4871 // We need to help the FakeLanguageModel complete the stream
4872 let inner_fake = fail_once_model.inner.clone();
4873
4874 // Wait a bit for the retry to start
4875 cx.run_until_parked();
4876
4877 // Check for pending completions and complete them
4878 if let Some(pending) = inner_fake.pending_completions().first() {
4879 inner_fake.stream_completion_response(pending, "Success!");
4880 inner_fake.end_completion_stream(pending);
4881 }
4882 cx.run_until_parked();
4883
4884 thread.read_with(cx, |thread, _| {
4885 assert!(
4886 thread.retry_state.is_none(),
4887 "Retry state should be cleared after successful completion"
4888 );
4889
4890 let has_assistant_message = thread
4891 .messages
4892 .iter()
4893 .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4894 assert!(
4895 has_assistant_message,
4896 "Should have an assistant message after successful retry"
4897 );
4898 });
4899 }
4900
4901 #[gpui::test]
4902 async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4903 init_test_settings(cx);
4904
4905 let project = create_test_project(cx, json!({})).await;
4906 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4907
4908 // Create a model that returns rate limit error with retry_after
4909 struct RateLimitModel {
4910 inner: Arc<FakeLanguageModel>,
4911 }
4912
4913 impl LanguageModel for RateLimitModel {
4914 fn id(&self) -> LanguageModelId {
4915 self.inner.id()
4916 }
4917
4918 fn name(&self) -> LanguageModelName {
4919 self.inner.name()
4920 }
4921
4922 fn provider_id(&self) -> LanguageModelProviderId {
4923 self.inner.provider_id()
4924 }
4925
4926 fn provider_name(&self) -> LanguageModelProviderName {
4927 self.inner.provider_name()
4928 }
4929
4930 fn supports_tools(&self) -> bool {
4931 self.inner.supports_tools()
4932 }
4933
4934 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4935 self.inner.supports_tool_choice(choice)
4936 }
4937
4938 fn supports_images(&self) -> bool {
4939 self.inner.supports_images()
4940 }
4941
4942 fn telemetry_id(&self) -> String {
4943 self.inner.telemetry_id()
4944 }
4945
4946 fn max_token_count(&self) -> u64 {
4947 self.inner.max_token_count()
4948 }
4949
4950 fn count_tokens(
4951 &self,
4952 request: LanguageModelRequest,
4953 cx: &App,
4954 ) -> BoxFuture<'static, Result<u64>> {
4955 self.inner.count_tokens(request, cx)
4956 }
4957
4958 fn stream_completion(
4959 &self,
4960 _request: LanguageModelRequest,
4961 _cx: &AsyncApp,
4962 ) -> BoxFuture<
4963 'static,
4964 Result<
4965 BoxStream<
4966 'static,
4967 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4968 >,
4969 LanguageModelCompletionError,
4970 >,
4971 > {
4972 async move {
4973 let stream = futures::stream::once(async move {
4974 Err(LanguageModelCompletionError::RateLimitExceeded {
4975 retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
4976 })
4977 });
4978 Ok(stream.boxed())
4979 }
4980 .boxed()
4981 }
4982
4983 fn as_fake(&self) -> &FakeLanguageModel {
4984 &self.inner
4985 }
4986 }
4987
4988 let model = Arc::new(RateLimitModel {
4989 inner: Arc::new(FakeLanguageModel::default()),
4990 });
4991
4992 // Insert a user message
4993 thread.update(cx, |thread, cx| {
4994 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4995 });
4996
4997 // Start completion
4998 thread.update(cx, |thread, cx| {
4999 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5000 });
5001
5002 cx.run_until_parked();
5003
5004 let retry_count = thread.update(cx, |thread, _| {
5005 thread
5006 .messages
5007 .iter()
5008 .filter(|m| {
5009 m.ui_only
5010 && m.segments.iter().any(|s| {
5011 if let MessageSegment::Text(text) = s {
5012 text.contains("rate limit exceeded")
5013 } else {
5014 false
5015 }
5016 })
5017 })
5018 .count()
5019 });
5020 assert_eq!(retry_count, 1, "Should have scheduled one retry");
5021
5022 thread.read_with(cx, |thread, _| {
5023 assert!(
5024 thread.retry_state.is_none(),
5025 "Rate limit errors should not set retry_state"
5026 );
5027 });
5028
5029 // Verify we have one retry message
5030 thread.read_with(cx, |thread, _| {
5031 let retry_messages = thread
5032 .messages
5033 .iter()
5034 .filter(|msg| {
5035 msg.ui_only
5036 && msg.segments.iter().any(|seg| {
5037 if let MessageSegment::Text(text) = seg {
5038 text.contains("rate limit exceeded")
5039 } else {
5040 false
5041 }
5042 })
5043 })
5044 .count();
5045 assert_eq!(
5046 retry_messages, 1,
5047 "Should have one rate limit retry message"
5048 );
5049 });
5050
5051 // Check that retry message doesn't include attempt count
5052 thread.read_with(cx, |thread, _| {
5053 let retry_message = thread
5054 .messages
5055 .iter()
5056 .find(|msg| msg.role == Role::System && msg.ui_only)
5057 .expect("Should have a retry message");
5058
5059 // Check that the message doesn't contain attempt count
5060 if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5061 assert!(
5062 !text.contains("attempt"),
5063 "Rate limit retry message should not contain attempt count"
5064 );
5065 assert!(
5066 text.contains(&format!(
5067 "Retrying in {} seconds",
5068 TEST_RATE_LIMIT_RETRY_SECS
5069 )),
5070 "Rate limit retry message should contain retry delay"
5071 );
5072 }
5073 });
5074 }
5075
5076 #[gpui::test]
5077 async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5078 init_test_settings(cx);
5079
5080 let project = create_test_project(cx, json!({})).await;
5081 let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5082
5083 // Insert a regular user message
5084 thread.update(cx, |thread, cx| {
5085 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5086 });
5087
5088 // Insert a UI-only message (like our retry notifications)
5089 thread.update(cx, |thread, cx| {
5090 let id = thread.next_message_id.post_inc();
5091 thread.messages.push(Message {
5092 id,
5093 role: Role::System,
5094 segments: vec![MessageSegment::Text(
5095 "This is a UI-only message that should not be sent to the model".to_string(),
5096 )],
5097 loaded_context: LoadedContext::default(),
5098 creases: Vec::new(),
5099 is_hidden: true,
5100 ui_only: true,
5101 });
5102 cx.emit(ThreadEvent::MessageAdded(id));
5103 });
5104
5105 // Insert another regular message
5106 thread.update(cx, |thread, cx| {
5107 thread.insert_user_message(
5108 "How are you?",
5109 ContextLoadResult::default(),
5110 None,
5111 vec![],
5112 cx,
5113 );
5114 });
5115
5116 // Generate the completion request
5117 let request = thread.update(cx, |thread, cx| {
5118 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5119 });
5120
5121 // Verify that the request only contains non-UI-only messages
5122 // Should have system prompt + 2 user messages, but not the UI-only message
5123 let user_messages: Vec<_> = request
5124 .messages
5125 .iter()
5126 .filter(|msg| msg.role == Role::User)
5127 .collect();
5128 assert_eq!(
5129 user_messages.len(),
5130 2,
5131 "Should have exactly 2 user messages"
5132 );
5133
5134 // Verify the UI-only content is not present anywhere in the request
5135 let request_text = request
5136 .messages
5137 .iter()
5138 .flat_map(|msg| &msg.content)
5139 .filter_map(|content| match content {
5140 MessageContent::Text(text) => Some(text.as_str()),
5141 _ => None,
5142 })
5143 .collect::<String>();
5144
5145 assert!(
5146 !request_text.contains("UI-only message"),
5147 "UI-only message content should not be in the request"
5148 );
5149
5150 // Verify the thread still has all 3 messages (including UI-only)
5151 thread.read_with(cx, |thread, _| {
5152 assert_eq!(
5153 thread.messages().count(),
5154 3,
5155 "Thread should have 3 messages"
5156 );
5157 assert_eq!(
5158 thread.messages().filter(|m| m.ui_only).count(),
5159 1,
5160 "Thread should have 1 UI-only message"
5161 );
5162 });
5163
5164 // Verify that UI-only messages are not serialized
5165 let serialized = thread
5166 .update(cx, |thread, cx| thread.serialize(cx))
5167 .await
5168 .unwrap();
5169 assert_eq!(
5170 serialized.messages.len(),
5171 2,
5172 "Serialized thread should only have 2 messages (no UI-only)"
5173 );
5174 }
5175
5176 #[gpui::test]
5177 async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5178 init_test_settings(cx);
5179
5180 let project = create_test_project(cx, json!({})).await;
5181 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5182
5183 // Create model that returns overloaded error
5184 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5185
5186 // Insert a user message
5187 thread.update(cx, |thread, cx| {
5188 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5189 });
5190
5191 // Start completion
5192 thread.update(cx, |thread, cx| {
5193 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5194 });
5195
5196 cx.run_until_parked();
5197
5198 // Verify retry was scheduled by checking for retry message
5199 let has_retry_message = thread.read_with(cx, |thread, _| {
5200 thread.messages.iter().any(|m| {
5201 m.ui_only
5202 && m.segments.iter().any(|s| {
5203 if let MessageSegment::Text(text) = s {
5204 text.contains("Retrying") && text.contains("seconds")
5205 } else {
5206 false
5207 }
5208 })
5209 })
5210 });
5211 assert!(has_retry_message, "Should have scheduled a retry");
5212
5213 // Cancel the completion before the retry happens
5214 thread.update(cx, |thread, cx| {
5215 thread.cancel_last_completion(None, cx);
5216 });
5217
5218 cx.run_until_parked();
5219
5220 // The retry should not have happened - no pending completions
5221 let fake_model = model.as_fake();
5222 assert_eq!(
5223 fake_model.pending_completions().len(),
5224 0,
5225 "Should have no pending completions after cancellation"
5226 );
5227
5228 // Verify the retry was cancelled by checking retry state
5229 thread.read_with(cx, |thread, _| {
5230 if let Some(retry_state) = &thread.retry_state {
5231 panic!(
5232 "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5233 retry_state.attempt, retry_state.max_attempts, retry_state.intent
5234 );
5235 }
5236 });
5237 }
5238
5239 fn test_summarize_error(
5240 model: &Arc<dyn LanguageModel>,
5241 thread: &Entity<Thread>,
5242 cx: &mut TestAppContext,
5243 ) {
5244 thread.update(cx, |thread, cx| {
5245 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5246 thread.send_to_model(
5247 model.clone(),
5248 CompletionIntent::ThreadSummarization,
5249 None,
5250 cx,
5251 );
5252 });
5253
5254 let fake_model = model.as_fake();
5255 simulate_successful_response(&fake_model, cx);
5256
5257 thread.read_with(cx, |thread, _| {
5258 assert!(matches!(thread.summary(), ThreadSummary::Generating));
5259 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5260 });
5261
5262 // Simulate summary request ending
5263 cx.run_until_parked();
5264 fake_model.end_last_completion_stream();
5265 cx.run_until_parked();
5266
5267 // State is set to Error and default message
5268 thread.read_with(cx, |thread, _| {
5269 assert!(matches!(thread.summary(), ThreadSummary::Error));
5270 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5271 });
5272 }
5273
    /// Flushes pending work, streams a canned assistant reply into the fake
    /// model's most recent completion, then closes the stream and flushes
    /// again so the thread fully processes the response.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
5280
    /// Registers the global settings and subsystems the thread tests depend on.
    /// Ordering matters: the `SettingsStore` global must be set before the
    /// per-crate `init`/`register` calls that read or register settings.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
5296
5297 // Helper to create a test project with test files
5298 async fn create_test_project(
5299 cx: &mut TestAppContext,
5300 files: serde_json::Value,
5301 ) -> Entity<Project> {
5302 let fs = FakeFs::new(cx.executor());
5303 fs.insert_tree(path!("/test"), files).await;
5304 Project::test(fs, [path!("/test").as_ref()], cx).await
5305 }
5306
5307 async fn setup_test_environment(
5308 cx: &mut TestAppContext,
5309 project: Entity<Project>,
5310 ) -> (
5311 Entity<Workspace>,
5312 Entity<ThreadStore>,
5313 Entity<Thread>,
5314 Entity<ContextStore>,
5315 Arc<dyn LanguageModel>,
5316 ) {
5317 let (workspace, cx) =
5318 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5319
5320 let thread_store = cx
5321 .update(|_, cx| {
5322 ThreadStore::load(
5323 project.clone(),
5324 cx.new(|_| ToolWorkingSet::default()),
5325 None,
5326 Arc::new(PromptBuilder::new(None).unwrap()),
5327 cx,
5328 )
5329 })
5330 .await
5331 .unwrap();
5332
5333 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5334 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5335
5336 let provider = Arc::new(FakeLanguageModelProvider);
5337 let model = provider.test_model();
5338 let model: Arc<dyn LanguageModel> = Arc::new(model);
5339
5340 cx.update(|_, cx| {
5341 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5342 registry.set_default_model(
5343 Some(ConfiguredModel {
5344 provider: provider.clone(),
5345 model: model.clone(),
5346 }),
5347 cx,
5348 );
5349 registry.set_thread_summary_model(
5350 Some(ConfiguredModel {
5351 provider,
5352 model: model.clone(),
5353 }),
5354 cx,
5355 );
5356 })
5357 });
5358
5359 (workspace, thread_store, thread, context_store, model)
5360 }
5361
5362 async fn add_file_to_context(
5363 project: &Entity<Project>,
5364 context_store: &Entity<ContextStore>,
5365 path: &str,
5366 cx: &mut TestAppContext,
5367 ) -> Result<Entity<language::Buffer>> {
5368 let buffer_path = project
5369 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5370 .unwrap();
5371
5372 let buffer = project
5373 .update(cx, |project, cx| {
5374 project.open_buffer(buffer_path.clone(), cx)
5375 })
5376 .await
5377 .unwrap();
5378
5379 context_store.update(cx, |context_store, cx| {
5380 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5381 });
5382
5383 Ok(buffer)
5384 }
5385}