1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9 tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
10};
11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
12use anyhow::{Result, anyhow};
13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
14use chrono::{DateTime, Utc};
15use client::{ModelRequestUsage, RequestUsage};
16use collections::HashMap;
17use feature_flags::{self, FeatureFlagAppExt};
18use futures::{FutureExt, StreamExt as _, future::Shared};
19use git::repository::DiffType;
20use gpui::{
21 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
22 WeakEntity, Window,
23};
24use language_model::{
25 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
27 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
28 LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
29 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
30 TokenUsage,
31};
32use postage::stream::Stream as _;
33use project::{
34 Project,
35 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
36};
37use prompt_store::{ModelContext, PromptBuilder};
38use proto::Plan;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use settings::Settings;
42use std::{
43 io::Write,
44 ops::Range,
45 sync::Arc,
46 time::{Duration, Instant},
47};
48use thiserror::Error;
49use util::{ResultExt as _, debug_panic, post_inc};
50use uuid::Uuid;
51use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
52
/// Maximum number of automatic retries after a retryable completion failure.
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay (seconds) before retrying — presumably scaled per attempt by the
/// retry loop, which is outside this chunk; TODO confirm.
const BASE_RETRY_DELAY_SECS: u64 = 5;
55
56#[derive(
57 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
58)]
59pub struct ThreadId(Arc<str>);
60
61impl ThreadId {
62 pub fn new() -> Self {
63 Self(Uuid::new_v4().to_string().into())
64 }
65}
66
67impl std::fmt::Display for ThreadId {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 write!(f, "{}", self.0)
70 }
71}
72
73impl From<&str> for ThreadId {
74 fn from(value: &str) -> Self {
75 Self(value.into())
76 }
77}
78
79/// The ID of the user prompt that initiated a request.
80///
81/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
82#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
83pub struct PromptId(Arc<str>);
84
85impl PromptId {
86 pub fn new() -> Self {
87 Self(Uuid::new_v4().to_string().into())
88 }
89}
90
91impl std::fmt::Display for PromptId {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 write!(f, "{}", self.0)
94 }
95}
96
97#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
98pub struct MessageId(pub(crate) usize);
99
100impl MessageId {
101 fn post_inc(&mut self) -> Self {
102 Self(post_inc(&mut self.0))
103 }
104
105 pub fn as_usize(&self) -> usize {
106 self.0
107 }
108}
109
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Span of the message text the crease covers — presumably byte offsets;
    // TODO confirm against the editor code that consumes this.
    pub range: Range<usize>,
    // Icon and label used to render the crease placeholder.
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
119
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context attached to this message; only the `text` part survives
    /// deserialization (see `Thread::deserialize`).
    pub loaded_context: LoadedContext,
    /// Creases for resurrecting context markers when editing this message.
    pub creases: Vec<MessageCrease>,
    // Hidden messages don't count as visible user turns (see `is_turn_end`).
    pub is_hidden: bool,
    // UI-only messages are filtered out of persistence (see `serialize`).
    pub ui_only: bool,
}
131
132impl Message {
133 /// Returns whether the message contains any meaningful text that should be displayed
134 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
135 pub fn should_display_content(&self) -> bool {
136 self.segments.iter().all(|segment| segment.should_display())
137 }
138
139 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
140 if let Some(MessageSegment::Thinking {
141 text: segment,
142 signature: current_signature,
143 }) = self.segments.last_mut()
144 {
145 if let Some(signature) = signature {
146 *current_signature = Some(signature);
147 }
148 segment.push_str(text);
149 } else {
150 self.segments.push(MessageSegment::Thinking {
151 text: text.to_string(),
152 signature,
153 });
154 }
155 }
156
157 pub fn push_redacted_thinking(&mut self, data: String) {
158 self.segments.push(MessageSegment::RedactedThinking(data));
159 }
160
161 pub fn push_text(&mut self, text: &str) {
162 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
163 segment.push_str(text);
164 } else {
165 self.segments.push(MessageSegment::Text(text.to_string()));
166 }
167 }
168
169 pub fn to_string(&self) -> String {
170 let mut result = String::new();
171
172 if !self.loaded_context.text.is_empty() {
173 result.push_str(&self.loaded_context.text);
174 }
175
176 for segment in &self.segments {
177 match segment {
178 MessageSegment::Text(text) => result.push_str(text),
179 MessageSegment::Thinking { text, .. } => {
180 result.push_str("<think>\n");
181 result.push_str(text);
182 result.push_str("\n</think>");
183 }
184 MessageSegment::RedactedThinking(_) => {}
185 }
186 }
187
188 result
189 }
190}
191
/// One contiguous piece of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(String),
}

impl MessageSegment {
    /// Whether this segment carries content worth rendering.
    ///
    /// `Text` and `Thinking` segments are displayable only when their text is
    /// non-empty; `RedactedThinking` is opaque provider data and never shown.
    ///
    /// Fix: this previously returned `text.is_empty()`, inverting the
    /// semantics implied by the name and by `Message::should_display_content`.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }

    /// The plain text of a `Text` segment, or `None` for other variants.
    pub fn text(&self) -> Option<&str> {
        match self {
            MessageSegment::Text(text) => Some(text),
            _ => None,
        }
    }
}
218
/// Point-in-time capture of the project's worktrees and their git state.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Presumably paths of buffers with unsaved changes at capture time —
    // confirm in the (out-of-view) `Thread::project_snapshot` implementation.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // Presumably `None` when the worktree is not a git repository — confirm
    // where snapshots are built.
    pub git_state: Option<GitState>,
}

/// Git metadata captured alongside a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
239
/// Associates a git checkpoint with the message it was captured for.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Positive/negative feedback recorded for a thread or a single message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
251
252pub enum LastRestoreCheckpoint {
253 Pending {
254 message_id: MessageId,
255 },
256 Error {
257 message_id: MessageId,
258 error: String,
259 },
260}
261
262impl LastRestoreCheckpoint {
263 pub fn message_id(&self) -> MessageId {
264 match self {
265 LastRestoreCheckpoint::Pending { message_id } => *message_id,
266 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
267 }
268 }
269}
270
271#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
272pub enum DetailedSummaryState {
273 #[default]
274 NotGenerated,
275 Generating {
276 message_id: MessageId,
277 },
278 Generated {
279 text: SharedString,
280 message_id: MessageId,
281 },
282}
283
284impl DetailedSummaryState {
285 fn text(&self) -> Option<SharedString> {
286 if let Self::Generated { text, .. } = self {
287 Some(text.clone())
288 } else {
289 None
290 }
291 }
292}
293
/// Running token count for a thread, paired with the model's context maximum.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: u64,
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies the current usage against the context window.
    pub fn ratio(&self) -> TokenUsageRatio {
        // Debug builds allow overriding the warning threshold via env var.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or("0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            return TokenUsageRatio::Normal;
        }
        if self.total >= self.max {
            return TokenUsageRatio::Exceeded;
        }

        let fraction = self.total as f32 / self.max as f32;
        if fraction >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy with `tokens` added to the total; `max` is unchanged.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            max: self.max,
            total: self.total + tokens,
        }
    }
}

/// Severity bucket produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
338
/// Where a pending completion currently sits in the request lifecycle.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
345
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    // In-flight summary / detailed-summary generation tasks.
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Kept sorted by id: ids are handed out monotonically via `next_message_id`
    // and messages are appended (see `insert_message`); `message()` relies on
    // this for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per message, for restore/rollback.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken when a user message is inserted; finalized once
    // generation fully stops (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history plus the running total.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Optional observer of every completion request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
387
/// Tracks automatic retry of a failed completion.
#[derive(Clone, Debug)]
struct RetryState {
    // Attempt counter compared against `max_attempts` — exact semantics live
    // in the retry loop outside this chunk; TODO confirm.
    attempt: u8,
    max_attempts: u8,
    // Intent to resend with when retrying.
    intent: CompletionIntent,
}
394
395#[derive(Clone, Debug, PartialEq, Eq)]
396pub enum ThreadSummary {
397 Pending,
398 Generating,
399 Ready(SharedString),
400 Error,
401}
402
403impl ThreadSummary {
404 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
405
406 pub fn or_default(&self) -> SharedString {
407 self.unwrap_or(Self::DEFAULT)
408 }
409
410 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
411 self.ready().unwrap_or_else(|| message.into())
412 }
413
414 pub fn ready(&self) -> Option<SharedString> {
415 match self {
416 ThreadSummary::Ready(summary) => Some(summary.clone()),
417 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
418 }
419 }
420}
421
/// Recorded when the last request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
429
430impl Thread {
    /// Creates an empty thread backed by `project`.
    ///
    /// Completion mode and agent profile come from global `AgentSettings`,
    /// and the model from the registry default. A project snapshot task is
    /// started eagerly so it reflects the state at thread creation.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Kick off the snapshot now; consumers await the shared task.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
486
    /// Reconstructs a thread from its persisted form.
    ///
    /// The next message id continues after the highest persisted id. The
    /// serialized model selection is re-resolved against the registry, falling
    /// back to the default model; completion mode and profile fall back to
    /// global settings. Context handles and images are not persisted, so
    /// restored messages carry only context text and crease placeholders.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Re-resolve the persisted provider/model pair; if it is gone (e.g.
        // provider removed), fall back to the registry default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives persistence;
                    // live context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
610
    /// Installs a callback observing each completion request and the events
    /// (or error strings) it streamed. Invocation sites are outside this
    /// chunk — presumably once per completed request; confirm there.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches the agent profile; emits `ProfileChanged` only on an actual change.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }
633
    /// Whether the thread contains no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the last-modified timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to (and caching)
    /// the registry's default when none has been set.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }
664
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Sets the summary text, substituting the default for an empty string.
    /// Ignored while a summary is pending or being generated; emits
    /// `SummaryChanged` only when the text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
700
701 pub fn message(&self, id: MessageId) -> Option<&Message> {
702 let index = self
703 .messages
704 .binary_search_by(|message| message.id.cmp(&id))
705 .ok()?;
706
707 self.messages.get(index)
708 }
709
710 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
711 self.messages.iter()
712 }
713
    /// Whether a completion is in flight or any tool use is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    ///
    /// NOTE(review): the implementation only inspects `last_received_chunk_at`;
    /// the "None when not generating" contract presumably depends on that field
    /// being cleared elsewhere — confirm at the call sites.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before generation counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived, resetting staleness.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
736
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are blocked on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for message `id`, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
762
    /// Restores the project's git state to `checkpoint`.
    ///
    /// Marks the restore as pending (for UI feedback) while the git operation
    /// runs; on success the thread is truncated back to the checkpoint's
    /// message, on failure the error is kept for display.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any pending checkpoint is invalidated by the restore.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
798
    /// Finalizes the pending checkpoint once generation has fully stopped;
    /// no-op while a completion or tool use is still running, or when there
    /// is no pending checkpoint.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        self.finalize_checkpoint(pending_checkpoint, cx);
    }

    /// Keeps `pending_checkpoint` only if the repository changed since it was
    /// taken: compares it against a fresh checkpoint and records it when they
    /// differ — or when a git operation fails, erring on the side of keeping it.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // Comparison failure is treated as "changed".
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    /// Records a finalized checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
852
    /// The state of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes `message_id` and every later message, dropping checkpoints
    /// associated with the removed messages. No-op if the id is unknown.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
870
    /// The contexts attached to message `id` (empty for unknown ids and for
    /// deserialized messages, which don't restore live contexts).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }

    /// Whether the message at index `ix` ends a turn: either it is the final
    /// message and generation has stopped, or it is an assistant message
    /// followed by a visible (non-hidden) user message.
    pub fn is_turn_end(&self, ix: usize) -> bool {
        if self.messages.is_empty() {
            return false;
        }

        if !self.is_generating() && ix == self.messages.len() - 1 {
            return true;
        }

        let Some(message) = self.messages.get(ix) else {
            return false;
        };

        if message.role != Role::Assistant {
            return false;
        }

        // NOTE(review): `self.message(...)` re-resolves an id from a message we
        // already have a reference to — redundant lookup, but harmless.
        self.messages
            .get(ix + 1)
            .and_then(|message| {
                self.message(message.id)
                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
            })
            .unwrap_or(false)
    }
904
    /// Whether the provider reported reaching its tool-use limit.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            // Errored tool uses are finished and can no longer edit.
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }
927
    /// Tool uses issued by the assistant message `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Results of tool uses issued by the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a single tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual output of a completed tool use; `None` for image output.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The UI card attached to a tool's result, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
956
957 /// Return tools that are both enabled and supported by the model
958 pub fn available_tools(
959 &self,
960 cx: &App,
961 model: Arc<dyn LanguageModel>,
962 ) -> Vec<LanguageModelRequestTool> {
963 if model.supports_tools() {
964 self.profile
965 .enabled_tools(cx)
966 .into_iter()
967 .filter_map(|(name, tool)| {
968 // Skip tools that cannot be supported
969 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
970 Some(LanguageModelRequestTool {
971 name: name.into(),
972 description: tool.description(),
973 input_schema,
974 })
975 })
976 .collect()
977 } else {
978 Vec::default()
979 }
980 }
981
    /// Appends a user message.
    ///
    /// Buffers referenced by the loaded context are registered as read in the
    /// action log, and `git_checkpoint` (when provided) becomes the pending
    /// checkpoint to be finalized after the turn completes.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Appends a hidden "Continue where you left off" user message, used to
    /// resume generation without a visible user turn. Clears any pending
    /// checkpoint.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }
1032
    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }

    /// Core insertion path: assigns the next id, appends the message, bumps
    /// `updated_at`, and emits [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
            ui_only: false,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1071
    /// Rewrites an existing message in place; returns `false` if `id` is unknown.
    ///
    /// `loaded_context` replaces the stored context only when `Some`, and a
    /// provided `checkpoint` replaces the message's recorded git checkpoint.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes a message; returns `false` if `id` is unknown.
    ///
    /// NOTE(review): unlike `truncate`, this leaves any entry for `id` in
    /// `checkpoints_by_message` behind — confirm whether that is intentional.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
1114
1115 /// Returns the representation of this [`Thread`] in a textual form.
1116 ///
1117 /// This is the representation we use when attaching a thread as context to another thread.
1118 pub fn text(&self) -> String {
1119 let mut text = String::new();
1120
1121 for message in &self.messages {
1122 text.push_str(match message.role {
1123 language_model::Role::User => "User:",
1124 language_model::Role::Assistant => "Agent:",
1125 language_model::Role::System => "System:",
1126 });
1127 text.push('\n');
1128
1129 for segment in &message.segments {
1130 match segment {
1131 MessageSegment::Text(content) => text.push_str(content),
1132 MessageSegment::Thinking { text: content, .. } => {
1133 text.push_str(&format!("<think>{}</think>", content))
1134 }
1135 MessageSegment::RedactedThinking(_) => {}
1136 }
1137 }
1138 text.push('\n');
1139 }
1140
1141 text
1142 }
1143
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot future to resolve, then reads
    /// the thread state into a [`SerializedThread`]. UI-only messages (e.g.
    /// retry notices) are excluded from the serialized form.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                // ui_only messages exist for display purposes only and are
                // never persisted.
                messages: this
                    .messages()
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1230
    /// Returns how many agentic turns this thread may still take before it
    /// stops sending requests to the model.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1234
    /// Sets how many agentic turns this thread may still take; see
    /// `send_to_model`, which decrements this budget per request.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1238
1239 pub fn send_to_model(
1240 &mut self,
1241 model: Arc<dyn LanguageModel>,
1242 intent: CompletionIntent,
1243 window: Option<AnyWindowHandle>,
1244 cx: &mut Context<Self>,
1245 ) {
1246 if self.remaining_turns == 0 {
1247 return;
1248 }
1249
1250 self.remaining_turns -= 1;
1251
1252 self.flush_notifications(model.clone(), intent, cx);
1253
1254 let request = self.to_completion_request(model.clone(), intent, cx);
1255
1256 self.stream_completion(request, model, intent, window, cx);
1257 }
1258
1259 pub fn used_tools_since_last_user_message(&self) -> bool {
1260 for message in self.messages.iter().rev() {
1261 if self.tool_use.message_has_tool_results(message.id) {
1262 return true;
1263 } else if message.role == Role::User {
1264 return false;
1265 }
1266 }
1267
1268 false
1269 }
1270
    /// Builds the [`LanguageModelRequest`] representing this thread's current
    /// state: system prompt, all non-UI-only messages with their tool
    /// uses/results, the available tools, and the completion mode.
    ///
    /// Emits a `ShowError` thread event if the system prompt cannot be
    /// generated. The last fully-resolved message is marked for prompt
    /// caching.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
            thinking_allowed: true,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt is generated from the shared project context plus
        // the set of tools the model may call.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the index of the newest request message that is safe to mark
        // as a cache breakpoint (i.e. has no pending tool uses before it).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are paired with their results: the tool use
            // goes on this message, the result goes on a follow-up User
            // message, per the provider's tool-calling protocol.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1435
1436 fn to_summarize_request(
1437 &self,
1438 model: &Arc<dyn LanguageModel>,
1439 intent: CompletionIntent,
1440 added_user_message: String,
1441 cx: &App,
1442 ) -> LanguageModelRequest {
1443 let mut request = LanguageModelRequest {
1444 thread_id: None,
1445 prompt_id: None,
1446 intent: Some(intent),
1447 mode: None,
1448 messages: vec![],
1449 tools: Vec::new(),
1450 tool_choice: None,
1451 stop: Vec::new(),
1452 temperature: AgentSettings::temperature_for_model(model, cx),
1453 thinking_allowed: false,
1454 };
1455
1456 for message in &self.messages {
1457 let mut request_message = LanguageModelRequestMessage {
1458 role: message.role,
1459 content: Vec::new(),
1460 cache: false,
1461 };
1462
1463 for segment in &message.segments {
1464 match segment {
1465 MessageSegment::Text(text) => request_message
1466 .content
1467 .push(MessageContent::Text(text.clone())),
1468 MessageSegment::Thinking { .. } => {}
1469 MessageSegment::RedactedThinking(_) => {}
1470 }
1471 }
1472
1473 if request_message.content.is_empty() {
1474 continue;
1475 }
1476
1477 request.messages.push(request_message);
1478 }
1479
1480 request.messages.push(LanguageModelRequestMessage {
1481 role: Role::User,
1482 content: vec![MessageContent::Text(added_user_message)],
1483 cache: false,
1484 });
1485
1486 request
1487 }
1488
1489 /// Insert auto-generated notifications (if any) to the thread
1490 fn flush_notifications(
1491 &mut self,
1492 model: Arc<dyn LanguageModel>,
1493 intent: CompletionIntent,
1494 cx: &mut Context<Self>,
1495 ) {
1496 match intent {
1497 CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
1498 if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
1499 cx.emit(ThreadEvent::ToolFinished {
1500 tool_use_id: pending_tool_use.id.clone(),
1501 pending_tool_use: Some(pending_tool_use),
1502 });
1503 }
1504 }
1505 CompletionIntent::ThreadSummarization
1506 | CompletionIntent::ThreadContextSummarization
1507 | CompletionIntent::CreateFile
1508 | CompletionIntent::EditFile
1509 | CompletionIntent::InlineAssist
1510 | CompletionIntent::TerminalInlineAssist
1511 | CompletionIntent::GenerateGitCommitMessage => {}
1512 };
1513 }
1514
    /// If there are stale buffers the agent has not been notified about,
    /// synthesizes a `project_notifications` tool call on the latest
    /// Assistant message and records its output.
    ///
    /// Returns the resulting pending tool use, or `None` when there is
    /// nothing to notify, the tool is unavailable/disabled, or no Assistant
    /// message exists yet.
    fn attach_tracked_files_state(
        &mut self,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Option<PendingToolUse> {
        let action_log = self.action_log.read(cx);

        // Bail out early when there are no un-notified stale buffers.
        action_log.unnotified_stale_buffers(cx).next()?;

        // Represent notification as a simulated `project_notifications` tool call
        let tool_name = Arc::from("project_notifications");
        let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
            debug_panic!("`project_notifications` tool not found");
            return None;
        };

        // Respect the active profile's tool configuration.
        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return None;
        }

        let input = serde_json::json!({});
        let request = Arc::new(LanguageModelRequest::default()); // unused
        let window = None;
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model.clone(),
            window,
            cx,
        );

        // Synthetic id; the message count keeps it unique per thread state.
        let tool_use_id =
            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));

        let tool_use = LanguageModelToolUse {
            id: tool_use_id.clone(),
            name: tool_name.clone(),
            raw_input: "{}".to_string(),
            input: serde_json::json!({}),
            is_input_complete: true,
        };

        // Block until the tool produces its output; the tool is expected to
        // be fast since it only reads tracked-buffer state.
        let tool_output = cx.background_executor().block(tool_result.output);

        // Attach a project_notification tool call to the latest existing
        // Assistant message. We cannot create a new Assistant message
        // because thinking models require a `thinking` block that we
        // cannot mock. We cannot send a notification as a normal
        // (non-tool-use) User message because this distracts Agent
        // too much.
        let tool_message_id = self
            .messages
            .iter()
            .enumerate()
            .rfind(|(_, message)| message.role == Role::Assistant)
            .map(|(_, message)| message.id)?;

        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: self.last_prompt_id.clone(),
        };

        // Record the tool use first, then immediately insert its output so
        // the pair shows up as a completed call.
        self.tool_use
            .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            tool_output,
            self.configured_model.as_ref(),
            self.completion_mode,
        );

        pending_tool_use
    }
1593
    /// Streams a completion for `request` from `model`, updating the thread
    /// as events arrive.
    ///
    /// Spawns a task that consumes the event stream (text/thinking chunks,
    /// tool uses, usage updates, status updates), then — on completion —
    /// dispatches on the stop reason or error: running pending tools,
    /// unwinding refused turns, scheduling retries for retryable errors, and
    /// reporting telemetry. The task is tracked in `pending_completions`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only capture request/response data when a test/debug callback is
        // installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let completion_mode = request
            .mode
            .unwrap_or(zed_llm_client::CompletionMode::Normal);

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the delta can be reported to
            // telemetry afterwards.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The Assistant message created for this response, if any has
                // been started yet.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        match event? {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // only add the delta since the last update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking { data } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses must hang off an Assistant
                                // message; create an empty one if the stream
                                // hasn't produced one yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially-streamed inputs are relayed
                                // as StreamedToolUse events below.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            } => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued { position } => {
                                            completion.queue_state =
                                                QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code,
                                            message,
                                            request_id: _,
                                            retry_after,
                                        } => {
                                            // Convert the cloud-reported
                                            // failure into a completion error
                                            // and abort the stream.
                                            return Err(
                                                LanguageModelCompletionError::from_cloud_failure(
                                                    model.upstream_provider_name(),
                                                    code,
                                                    message,
                                                    retry_after.map(Duration::from_secs_f64),
                                                ),
                                            );
                                        }
                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
                                            thread.update_model_request_usage(
                                                amount as u32,
                                                limit,
                                                cx,
                                            );
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other tasks (e.g. UI updates) can run between
                    // events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;
            let mut retry_scheduled = false;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => {
                            match stop_reason {
                                StopReason::ToolUse => {
                                    let tool_uses =
                                        thread.use_pending_tools(window, model.clone(), cx);
                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                                }
                                StopReason::EndTurn | StopReason::MaxTokens => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });
                                }
                                StopReason::Refusal => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });

                                    // Remove the turn that was refused.
                                    //
                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                    {
                                        let mut messages_to_remove = Vec::new();

                                        for (ix, message) in
                                            thread.messages.iter().enumerate().rev()
                                        {
                                            messages_to_remove.push(message.id);

                                            if message.role == Role::User {
                                                if ix == 0 {
                                                    break;
                                                }

                                                if let Some(prev_message) =
                                                    thread.messages.get(ix - 1)
                                                {
                                                    if prev_message.role == Role::Assistant {
                                                        break;
                                                    }
                                                }
                                            }
                                        }

                                        for message_id in messages_to_remove {
                                            thread.delete_message(message_id, cx);
                                        }
                                    }

                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                        header: "Language model refusal".into(),
                                        message:
                                            "Model refused to generate content for safety reasons."
                                                .into(),
                                    }));
                                }
                            }

                            // We successfully completed, so cancel any remaining retries.
                            thread.retry_state = None;
                        }
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(completion_error) =
                                error.downcast_ref::<LanguageModelCompletionError>()
                            {
                                use LanguageModelCompletionError::*;
                                match &completion_error {
                                    PromptTooLarge { tokens, .. } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread
                                                .total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(
                                                    model
                                                        .max_token_count_for_mode(completion_mode)
                                                        .saturating_add(1),
                                                )
                                        });
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: tokens,
                                        });
                                        cx.notify();
                                    }
                                    RateLimitExceeded {
                                        retry_after: Some(retry_after),
                                        ..
                                    }
                                    | ServerOverloaded {
                                        retry_after: Some(retry_after),
                                        ..
                                    } => {
                                        // The server told us exactly how long
                                        // to wait, so honor that delay.
                                        thread.handle_rate_limit_error(
                                            &completion_error,
                                            *retry_after,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        retry_scheduled = true;
                                    }
                                    RateLimitExceeded { .. } | ServerOverloaded { .. } => {
                                        retry_scheduled = thread.handle_retryable_error(
                                            &completion_error,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        if !retry_scheduled {
                                            emit_generic_error(error, cx);
                                        }
                                    }
                                    ApiInternalServerError { .. }
                                    | ApiReadResponseError { .. }
                                    | HttpSend { .. } => {
                                        retry_scheduled = thread.handle_retryable_error(
                                            &completion_error,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        if !retry_scheduled {
                                            emit_generic_error(error, cx);
                                        }
                                    }
                                    NoApiKey { .. }
                                    | HttpResponseError { .. }
                                    | BadRequestFormat { .. }
                                    | AuthenticationError { .. }
                                    | PermissionError { .. }
                                    | ApiEndpointNotFound { .. }
                                    | SerializeRequest { .. }
                                    | BuildRequestBody { .. }
                                    | DeserializeResponse { .. }
                                    | Other { .. } => emit_generic_error(error, cx),
                                }
                            } else {
                                emit_generic_error(error, cx);
                            }

                            if !retry_scheduled {
                                thread.cancel_last_completion(window, cx);
                            }
                        }
                    }

                    if !retry_scheduled {
                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
                    }

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
2086
2087 pub fn summarize(&mut self, cx: &mut Context<Self>) {
2088 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2089 println!("No thread summary model");
2090 return;
2091 };
2092
2093 if !model.provider.is_authenticated(cx) {
2094 return;
2095 }
2096
2097 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2098
2099 let request = self.to_summarize_request(
2100 &model.model,
2101 CompletionIntent::ThreadSummarization,
2102 added_user_message.into(),
2103 cx,
2104 );
2105
2106 self.summary = ThreadSummary::Generating;
2107
2108 self.pending_summary = cx.spawn(async move |this, cx| {
2109 let result = async {
2110 let mut messages = model.model.stream_completion(request, &cx).await?;
2111
2112 let mut new_summary = String::new();
2113 while let Some(event) = messages.next().await {
2114 let Ok(event) = event else {
2115 continue;
2116 };
2117 let text = match event {
2118 LanguageModelCompletionEvent::Text(text) => text,
2119 LanguageModelCompletionEvent::StatusUpdate(
2120 CompletionRequestStatus::UsageUpdated { amount, limit },
2121 ) => {
2122 this.update(cx, |thread, cx| {
2123 thread.update_model_request_usage(amount as u32, limit, cx);
2124 })?;
2125 continue;
2126 }
2127 _ => continue,
2128 };
2129
2130 let mut lines = text.lines();
2131 new_summary.extend(lines.next());
2132
2133 // Stop if the LLM generated multiple lines.
2134 if lines.next().is_some() {
2135 break;
2136 }
2137 }
2138
2139 anyhow::Ok(new_summary)
2140 }
2141 .await;
2142
2143 this.update(cx, |this, cx| {
2144 match result {
2145 Ok(new_summary) => {
2146 if new_summary.is_empty() {
2147 this.summary = ThreadSummary::Error;
2148 } else {
2149 this.summary = ThreadSummary::Ready(new_summary.into());
2150 }
2151 }
2152 Err(err) => {
2153 this.summary = ThreadSummary::Error;
2154 log::error!("Failed to generate thread summary: {}", err);
2155 }
2156 }
2157 cx.emit(ThreadEvent::SummaryGenerated);
2158 })
2159 .log_err()?;
2160
2161 Some(())
2162 });
2163 }
2164
2165 fn handle_rate_limit_error(
2166 &mut self,
2167 error: &LanguageModelCompletionError,
2168 retry_after: Duration,
2169 model: Arc<dyn LanguageModel>,
2170 intent: CompletionIntent,
2171 window: Option<AnyWindowHandle>,
2172 cx: &mut Context<Self>,
2173 ) {
2174 // For rate limit errors, we only retry once with the specified duration
2175 let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
2176 log::warn!(
2177 "Retrying completion request in {} seconds: {error:?}",
2178 retry_after.as_secs(),
2179 );
2180
2181 // Add a UI-only message instead of a regular message
2182 let id = self.next_message_id.post_inc();
2183 self.messages.push(Message {
2184 id,
2185 role: Role::System,
2186 segments: vec![MessageSegment::Text(retry_message)],
2187 loaded_context: LoadedContext::default(),
2188 creases: Vec::new(),
2189 is_hidden: false,
2190 ui_only: true,
2191 });
2192 cx.emit(ThreadEvent::MessageAdded(id));
2193 // Schedule the retry
2194 let thread_handle = cx.entity().downgrade();
2195
2196 cx.spawn(async move |_thread, cx| {
2197 cx.background_executor().timer(retry_after).await;
2198
2199 thread_handle
2200 .update(cx, |thread, cx| {
2201 // Retry the completion
2202 thread.send_to_model(model, intent, window, cx);
2203 })
2204 .log_err();
2205 })
2206 .detach();
2207 }
2208
    /// Schedules a retry with exponential backoff for a retryable completion
    /// error.
    ///
    /// Thin wrapper over `handle_retryable_error_with_delay` with no custom
    /// delay. Returns `true` if a retry was scheduled.
    fn handle_retryable_error(
        &mut self,
        error: &LanguageModelCompletionError,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
    }
2219
    /// Schedules a retry for a retryable completion error, using
    /// `custom_delay` when provided and exponential backoff otherwise.
    ///
    /// Tracks attempts in `self.retry_state` (up to `MAX_RETRY_ATTEMPTS`).
    /// Returns `true` when a retry was scheduled; on exhaustion, clears the
    /// retry state, stops pending completions, emits `RetriesFailed`, and
    /// returns `false`.
    fn handle_retryable_error_with_delay(
        &mut self,
        error: &LanguageModelCompletionError,
        custom_delay: Option<Duration>,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Reuse the existing retry state so attempts accumulate across
        // consecutive failures.
        let retry_state = self.retry_state.get_or_insert(RetryState {
            attempt: 0,
            max_attempts: MAX_RETRY_ATTEMPTS,
            intent,
        });

        retry_state.attempt += 1;
        let attempt = retry_state.attempt;
        let max_attempts = retry_state.max_attempts;
        let intent = retry_state.intent;

        if attempt <= max_attempts {
            // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
            let delay = if let Some(custom_delay) = custom_delay {
                custom_delay
            } else {
                // BASE * 2^(attempt-1): doubles with each failed attempt.
                let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
                Duration::from_secs(delay_secs)
            };

            // Add a transient message to inform the user
            let delay_secs = delay.as_secs();
            let retry_message = format!(
                "{error}. Retrying (attempt {attempt} of {max_attempts}) \
                in {delay_secs} seconds..."
            );
            log::warn!(
                "Retrying completion request (attempt {attempt} of {max_attempts}) \
                in {delay_secs} seconds: {error:?}",
            );

            // Add a UI-only message instead of a regular message
            let id = self.next_message_id.post_inc();
            self.messages.push(Message {
                id,
                role: Role::System,
                segments: vec![MessageSegment::Text(retry_message)],
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
                is_hidden: false,
                ui_only: true,
            });
            cx.emit(ThreadEvent::MessageAdded(id));

            // Schedule the retry
            let thread_handle = cx.entity().downgrade();

            cx.spawn(async move |_thread, cx| {
                cx.background_executor().timer(delay).await;

                thread_handle
                    .update(cx, |thread, cx| {
                        // Retry the completion
                        thread.send_to_model(model, intent, window, cx);
                    })
                    .log_err();
            })
            .detach();

            true
        } else {
            // Max retries exceeded
            self.retry_state = None;

            let notification_text = if max_attempts == 1 {
                "Failed after retrying.".into()
            } else {
                format!("Failed after retrying {} times.", max_attempts).into()
            };

            // Stop generating since we're giving up on retrying.
            self.pending_completions.clear();

            cx.emit(ThreadEvent::RetriesFailed {
                message: notification_text,
            });

            false
        }
    }
2309
    /// Kicks off generation of a detailed thread summary, unless one has
    /// already been generated (or is generating) for the current last message.
    ///
    /// The result is published through `detailed_summary_tx`, and the thread
    /// is saved afterwards so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        // Summarization uses the dedicated thread-summary model, not the chat model.
        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Mark as generating before spawning so observers see the transition.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails outright, reset state so a later call can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate streamed chunks; chunk-level errors are logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2399
2400 pub async fn wait_for_detailed_summary_or_text(
2401 this: &Entity<Self>,
2402 cx: &mut AsyncApp,
2403 ) -> Option<SharedString> {
2404 let mut detailed_summary_rx = this
2405 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2406 .ok()?;
2407 loop {
2408 match detailed_summary_rx.recv().await? {
2409 DetailedSummaryState::Generating { .. } => {}
2410 DetailedSummaryState::NotGenerated => {
2411 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2412 }
2413 DetailedSummaryState::Generated { text, .. } => return Some(text),
2414 }
2415 }
2416 }
2417
2418 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2419 self.detailed_summary_rx
2420 .borrow()
2421 .text()
2422 .unwrap_or_else(|| self.text().into())
2423 }
2424
2425 pub fn is_generating_detailed_summary(&self) -> bool {
2426 matches!(
2427 &*self.detailed_summary_rx.borrow(),
2428 DetailedSummaryState::Generating { .. }
2429 )
2430 }
2431
2432 pub fn use_pending_tools(
2433 &mut self,
2434 window: Option<AnyWindowHandle>,
2435 model: Arc<dyn LanguageModel>,
2436 cx: &mut Context<Self>,
2437 ) -> Vec<PendingToolUse> {
2438 self.auto_capture_telemetry(cx);
2439 let request =
2440 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2441 let pending_tool_uses = self
2442 .tool_use
2443 .pending_tool_uses()
2444 .into_iter()
2445 .filter(|tool_use| tool_use.status.is_idle())
2446 .cloned()
2447 .collect::<Vec<_>>();
2448
2449 for tool_use in pending_tool_uses.iter() {
2450 self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2451 }
2452
2453 pending_tool_uses
2454 }
2455
2456 fn use_pending_tool(
2457 &mut self,
2458 tool_use: PendingToolUse,
2459 request: Arc<LanguageModelRequest>,
2460 model: Arc<dyn LanguageModel>,
2461 window: Option<AnyWindowHandle>,
2462 cx: &mut Context<Self>,
2463 ) {
2464 let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2465 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2466 };
2467
2468 if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2469 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2470 }
2471
2472 if tool.needs_confirmation(&tool_use.input, cx)
2473 && !AgentSettings::get_global(cx).always_allow_tool_actions
2474 {
2475 self.tool_use.confirm_tool_use(
2476 tool_use.id,
2477 tool_use.ui_text,
2478 tool_use.input,
2479 request,
2480 tool,
2481 );
2482 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2483 } else {
2484 self.run_tool(
2485 tool_use.id,
2486 tool_use.ui_text,
2487 tool_use.input,
2488 request,
2489 tool,
2490 model,
2491 window,
2492 cx,
2493 );
2494 }
2495 }
2496
2497 pub fn handle_hallucinated_tool_use(
2498 &mut self,
2499 tool_use_id: LanguageModelToolUseId,
2500 hallucinated_tool_name: Arc<str>,
2501 window: Option<AnyWindowHandle>,
2502 cx: &mut Context<Thread>,
2503 ) {
2504 let available_tools = self.profile.enabled_tools(cx);
2505
2506 let tool_list = available_tools
2507 .iter()
2508 .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2509 .collect::<Vec<_>>()
2510 .join("\n");
2511
2512 let error_message = format!(
2513 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2514 hallucinated_tool_name, tool_list
2515 );
2516
2517 let pending_tool_use = self.tool_use.insert_tool_output(
2518 tool_use_id.clone(),
2519 hallucinated_tool_name,
2520 Err(anyhow!("Missing tool call: {error_message}")),
2521 self.configured_model.as_ref(),
2522 self.completion_mode,
2523 );
2524
2525 cx.emit(ThreadEvent::MissingToolUse {
2526 tool_use_id: tool_use_id.clone(),
2527 ui_text: error_message.into(),
2528 });
2529
2530 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2531 }
2532
2533 pub fn receive_invalid_tool_json(
2534 &mut self,
2535 tool_use_id: LanguageModelToolUseId,
2536 tool_name: Arc<str>,
2537 invalid_json: Arc<str>,
2538 error: String,
2539 window: Option<AnyWindowHandle>,
2540 cx: &mut Context<Thread>,
2541 ) {
2542 log::error!("The model returned invalid input JSON: {invalid_json}");
2543
2544 let pending_tool_use = self.tool_use.insert_tool_output(
2545 tool_use_id.clone(),
2546 tool_name,
2547 Err(anyhow!("Error parsing input JSON: {error}")),
2548 self.configured_model.as_ref(),
2549 self.completion_mode,
2550 );
2551 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2552 pending_tool_use.ui_text.clone()
2553 } else {
2554 log::error!(
2555 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2556 );
2557 format!("Unknown tool {}", tool_use_id).into()
2558 };
2559
2560 cx.emit(ThreadEvent::InvalidToolInput {
2561 tool_use_id: tool_use_id.clone(),
2562 ui_text,
2563 invalid_input_json: invalid_json,
2564 });
2565
2566 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2567 }
2568
2569 pub fn run_tool(
2570 &mut self,
2571 tool_use_id: LanguageModelToolUseId,
2572 ui_text: impl Into<SharedString>,
2573 input: serde_json::Value,
2574 request: Arc<LanguageModelRequest>,
2575 tool: Arc<dyn Tool>,
2576 model: Arc<dyn LanguageModel>,
2577 window: Option<AnyWindowHandle>,
2578 cx: &mut Context<Thread>,
2579 ) {
2580 let task =
2581 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2582 self.tool_use
2583 .run_pending_tool(tool_use_id, ui_text.into(), task);
2584 }
2585
2586 fn spawn_tool_use(
2587 &mut self,
2588 tool_use_id: LanguageModelToolUseId,
2589 request: Arc<LanguageModelRequest>,
2590 input: serde_json::Value,
2591 tool: Arc<dyn Tool>,
2592 model: Arc<dyn LanguageModel>,
2593 window: Option<AnyWindowHandle>,
2594 cx: &mut Context<Thread>,
2595 ) -> Task<()> {
2596 let tool_name: Arc<str> = tool.name().into();
2597
2598 let tool_result = tool.run(
2599 input,
2600 request,
2601 self.project.clone(),
2602 self.action_log.clone(),
2603 model,
2604 window,
2605 cx,
2606 );
2607
2608 // Store the card separately if it exists
2609 if let Some(card) = tool_result.card.clone() {
2610 self.tool_use
2611 .insert_tool_result_card(tool_use_id.clone(), card);
2612 }
2613
2614 cx.spawn({
2615 async move |thread: WeakEntity<Thread>, cx| {
2616 let output = tool_result.output.await;
2617
2618 thread
2619 .update(cx, |thread, cx| {
2620 let pending_tool_use = thread.tool_use.insert_tool_output(
2621 tool_use_id.clone(),
2622 tool_name,
2623 output,
2624 thread.configured_model.as_ref(),
2625 thread.completion_mode,
2626 );
2627 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2628 })
2629 .ok();
2630 }
2631 })
2632 }
2633
2634 fn tool_finished(
2635 &mut self,
2636 tool_use_id: LanguageModelToolUseId,
2637 pending_tool_use: Option<PendingToolUse>,
2638 canceled: bool,
2639 window: Option<AnyWindowHandle>,
2640 cx: &mut Context<Self>,
2641 ) {
2642 if self.all_tools_finished() {
2643 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2644 if !canceled {
2645 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2646 }
2647 self.auto_capture_telemetry(cx);
2648 }
2649 }
2650
2651 cx.emit(ThreadEvent::ToolFinished {
2652 tool_use_id,
2653 pending_tool_use,
2654 });
2655 }
2656
2657 /// Cancels the last pending completion, if there are any pending.
2658 ///
2659 /// Returns whether a completion was canceled.
2660 pub fn cancel_last_completion(
2661 &mut self,
2662 window: Option<AnyWindowHandle>,
2663 cx: &mut Context<Self>,
2664 ) -> bool {
2665 let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2666
2667 self.retry_state = None;
2668
2669 for pending_tool_use in self.tool_use.cancel_pending() {
2670 canceled = true;
2671 self.tool_finished(
2672 pending_tool_use.id.clone(),
2673 Some(pending_tool_use),
2674 true,
2675 window,
2676 cx,
2677 );
2678 }
2679
2680 if canceled {
2681 cx.emit(ThreadEvent::CompletionCanceled);
2682
2683 // When canceled, we always want to insert the checkpoint.
2684 // (We skip over finalize_pending_checkpoint, because it
2685 // would conclude we didn't have anything to insert here.)
2686 if let Some(checkpoint) = self.pending_checkpoint.take() {
2687 self.insert_checkpoint(checkpoint, cx);
2688 }
2689 } else {
2690 self.finalize_pending_checkpoint(cx);
2691 }
2692
2693 canceled
2694 }
2695
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // Purely an event broadcast; no thread state is mutated here.
        cx.emit(ThreadEvent::CancelEditing);
    }
2703
    /// Returns the thread-level feedback rating, if the user has provided one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2707
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2711
2712 pub fn report_message_feedback(
2713 &mut self,
2714 message_id: MessageId,
2715 feedback: ThreadFeedback,
2716 cx: &mut Context<Self>,
2717 ) -> Task<Result<()>> {
2718 if self.message_feedback.get(&message_id) == Some(&feedback) {
2719 return Task::ready(Ok(()));
2720 }
2721
2722 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2723 let serialized_thread = self.serialize(cx);
2724 let thread_id = self.id().clone();
2725 let client = self.project.read(cx).client();
2726
2727 let enabled_tool_names: Vec<String> = self
2728 .profile
2729 .enabled_tools(cx)
2730 .iter()
2731 .map(|(name, _)| name.clone().into())
2732 .collect();
2733
2734 self.message_feedback.insert(message_id, feedback);
2735
2736 cx.notify();
2737
2738 let message_content = self
2739 .message(message_id)
2740 .map(|msg| msg.to_string())
2741 .unwrap_or_default();
2742
2743 cx.background_spawn(async move {
2744 let final_project_snapshot = final_project_snapshot.await;
2745 let serialized_thread = serialized_thread.await?;
2746 let thread_data =
2747 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2748
2749 let rating = match feedback {
2750 ThreadFeedback::Positive => "positive",
2751 ThreadFeedback::Negative => "negative",
2752 };
2753 telemetry::event!(
2754 "Assistant Thread Rated",
2755 rating,
2756 thread_id,
2757 enabled_tool_names,
2758 message_id = message_id.0,
2759 message_content,
2760 thread_data,
2761 final_project_snapshot
2762 );
2763 client.telemetry().flush_events().await;
2764
2765 Ok(())
2766 })
2767 }
2768
2769 pub fn report_feedback(
2770 &mut self,
2771 feedback: ThreadFeedback,
2772 cx: &mut Context<Self>,
2773 ) -> Task<Result<()>> {
2774 let last_assistant_message_id = self
2775 .messages
2776 .iter()
2777 .rev()
2778 .find(|msg| msg.role == Role::Assistant)
2779 .map(|msg| msg.id);
2780
2781 if let Some(message_id) = last_assistant_message_id {
2782 self.report_message_feedback(message_id, feedback, cx)
2783 } else {
2784 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2785 let serialized_thread = self.serialize(cx);
2786 let thread_id = self.id().clone();
2787 let client = self.project.read(cx).client();
2788 self.feedback = Some(feedback);
2789 cx.notify();
2790
2791 cx.background_spawn(async move {
2792 let final_project_snapshot = final_project_snapshot.await;
2793 let serialized_thread = serialized_thread.await?;
2794 let thread_data = serde_json::to_value(serialized_thread)
2795 .unwrap_or_else(|_| serde_json::Value::Null);
2796
2797 let rating = match feedback {
2798 ThreadFeedback::Positive => "positive",
2799 ThreadFeedback::Negative => "negative",
2800 };
2801 telemetry::event!(
2802 "Assistant Thread Rated",
2803 rating,
2804 thread_id,
2805 thread_data,
2806 final_project_snapshot
2807 );
2808 client.telemetry().flush_events().await;
2809
2810 Ok(())
2811 })
2812 }
2813 }
2814
2815 /// Create a snapshot of the current project state including git information and unsaved buffers.
2816 fn project_snapshot(
2817 project: Entity<Project>,
2818 cx: &mut Context<Self>,
2819 ) -> Task<Arc<ProjectSnapshot>> {
2820 let git_store = project.read(cx).git_store().clone();
2821 let worktree_snapshots: Vec<_> = project
2822 .read(cx)
2823 .visible_worktrees(cx)
2824 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2825 .collect();
2826
2827 cx.spawn(async move |_, cx| {
2828 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2829
2830 let mut unsaved_buffers = Vec::new();
2831 cx.update(|app_cx| {
2832 let buffer_store = project.read(app_cx).buffer_store();
2833 for buffer_handle in buffer_store.read(app_cx).buffers() {
2834 let buffer = buffer_handle.read(app_cx);
2835 if buffer.is_dirty() {
2836 if let Some(file) = buffer.file() {
2837 let path = file.path().to_string_lossy().to_string();
2838 unsaved_buffers.push(path);
2839 }
2840 }
2841 }
2842 })
2843 .ok();
2844
2845 Arc::new(ProjectSnapshot {
2846 worktree_snapshots,
2847 unsaved_buffer_paths: unsaved_buffers,
2848 timestamp: Utc::now(),
2849 })
2850 })
2851 }
2852
2853 fn worktree_snapshot(
2854 worktree: Entity<project::Worktree>,
2855 git_store: Entity<GitStore>,
2856 cx: &App,
2857 ) -> Task<WorktreeSnapshot> {
2858 cx.spawn(async move |cx| {
2859 // Get worktree path and snapshot
2860 let worktree_info = cx.update(|app_cx| {
2861 let worktree = worktree.read(app_cx);
2862 let path = worktree.abs_path().to_string_lossy().to_string();
2863 let snapshot = worktree.snapshot();
2864 (path, snapshot)
2865 });
2866
2867 let Ok((worktree_path, _snapshot)) = worktree_info else {
2868 return WorktreeSnapshot {
2869 worktree_path: String::new(),
2870 git_state: None,
2871 };
2872 };
2873
2874 let git_state = git_store
2875 .update(cx, |git_store, cx| {
2876 git_store
2877 .repositories()
2878 .values()
2879 .find(|repo| {
2880 repo.read(cx)
2881 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2882 .is_some()
2883 })
2884 .cloned()
2885 })
2886 .ok()
2887 .flatten()
2888 .map(|repo| {
2889 repo.update(cx, |repo, _| {
2890 let current_branch =
2891 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2892 repo.send_job(None, |state, _| async move {
2893 let RepositoryState::Local { backend, .. } = state else {
2894 return GitState {
2895 remote_url: None,
2896 head_sha: None,
2897 current_branch,
2898 diff: None,
2899 };
2900 };
2901
2902 let remote_url = backend.remote_url("origin");
2903 let head_sha = backend.head_sha().await;
2904 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2905
2906 GitState {
2907 remote_url,
2908 head_sha,
2909 current_branch,
2910 diff,
2911 }
2912 })
2913 })
2914 });
2915
2916 let git_state = match git_state {
2917 Some(git_state) => match git_state.ok() {
2918 Some(git_state) => git_state.await.ok(),
2919 None => None,
2920 },
2921 None => None,
2922 };
2923
2924 WorktreeSnapshot {
2925 worktree_path,
2926 git_state,
2927 }
2928 })
2929 }
2930
2931 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2932 let mut markdown = Vec::new();
2933
2934 let summary = self.summary().or_default();
2935 writeln!(markdown, "# {summary}\n")?;
2936
2937 for message in self.messages() {
2938 writeln!(
2939 markdown,
2940 "## {role}\n",
2941 role = match message.role {
2942 Role::User => "User",
2943 Role::Assistant => "Agent",
2944 Role::System => "System",
2945 }
2946 )?;
2947
2948 if !message.loaded_context.text.is_empty() {
2949 writeln!(markdown, "{}", message.loaded_context.text)?;
2950 }
2951
2952 if !message.loaded_context.images.is_empty() {
2953 writeln!(
2954 markdown,
2955 "\n{} images attached as context.\n",
2956 message.loaded_context.images.len()
2957 )?;
2958 }
2959
2960 for segment in &message.segments {
2961 match segment {
2962 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2963 MessageSegment::Thinking { text, .. } => {
2964 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2965 }
2966 MessageSegment::RedactedThinking(_) => {}
2967 }
2968 }
2969
2970 for tool_use in self.tool_uses_for_message(message.id, cx) {
2971 writeln!(
2972 markdown,
2973 "**Use Tool: {} ({})**",
2974 tool_use.name, tool_use.id
2975 )?;
2976 writeln!(markdown, "```json")?;
2977 writeln!(
2978 markdown,
2979 "{}",
2980 serde_json::to_string_pretty(&tool_use.input)?
2981 )?;
2982 writeln!(markdown, "```")?;
2983 }
2984
2985 for tool_result in self.tool_results_for_message(message.id) {
2986 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2987 if tool_result.is_error {
2988 write!(markdown, " (Error)")?;
2989 }
2990
2991 writeln!(markdown, "**\n")?;
2992 match &tool_result.content {
2993 LanguageModelToolResultContent::Text(text) => {
2994 writeln!(markdown, "{text}")?;
2995 }
2996 LanguageModelToolResultContent::Image(image) => {
2997 writeln!(markdown, "", image.source)?;
2998 }
2999 }
3000
3001 if let Some(output) = tool_result.output.as_ref() {
3002 writeln!(
3003 markdown,
3004 "\n\nDebug Output:\n\n```json\n{}\n```\n",
3005 serde_json::to_string_pretty(output)?
3006 )?;
3007 }
3008 }
3009 }
3010
3011 Ok(String::from_utf8_lossy(&markdown).to_string())
3012 }
3013
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept
    /// (accepted), delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
3024
    /// Marks all of the agent's edits as kept (accepted), via the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
3029
    /// Reverts the agent's edits within the given ranges of `buffer`,
    /// delegating to the action log. Returns the revert task.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
3040
    /// The action log tracking the agent's buffer edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
3044
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
3048
3049 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
3050 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
3051 return;
3052 }
3053
3054 let now = Instant::now();
3055 if let Some(last) = self.last_auto_capture_at {
3056 if now.duration_since(last).as_secs() < 10 {
3057 return;
3058 }
3059 }
3060
3061 self.last_auto_capture_at = Some(now);
3062
3063 let thread_id = self.id().clone();
3064 let github_login = self
3065 .project
3066 .read(cx)
3067 .user_store()
3068 .read(cx)
3069 .current_user()
3070 .map(|user| user.github_login.clone());
3071 let client = self.project.read(cx).client();
3072 let serialize_task = self.serialize(cx);
3073
3074 cx.background_executor()
3075 .spawn(async move {
3076 if let Ok(serialized_thread) = serialize_task.await {
3077 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
3078 telemetry::event!(
3079 "Agent Thread Auto-Captured",
3080 thread_id = thread_id.to_string(),
3081 thread_data = thread_data,
3082 auto_capture_reason = "tracked_user",
3083 github_login = github_login
3084 );
3085
3086 client.telemetry().flush_events().await;
3087 }
3088 }
3089 })
3090 .detach();
3091 }
3092
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3096
3097 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3098 let Some(model) = self.configured_model.as_ref() else {
3099 return TotalTokenUsage::default();
3100 };
3101
3102 let max = model
3103 .model
3104 .max_token_count_for_mode(self.completion_mode().into());
3105
3106 let index = self
3107 .messages
3108 .iter()
3109 .position(|msg| msg.id == message_id)
3110 .unwrap_or(0);
3111
3112 if index == 0 {
3113 return TotalTokenUsage { total: 0, max };
3114 }
3115
3116 let token_usage = &self
3117 .request_token_usage
3118 .get(index - 1)
3119 .cloned()
3120 .unwrap_or_default();
3121
3122 TotalTokenUsage {
3123 total: token_usage.total_tokens(),
3124 max,
3125 }
3126 }
3127
3128 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3129 let model = self.configured_model.as_ref()?;
3130
3131 let max = model
3132 .model
3133 .max_token_count_for_mode(self.completion_mode().into());
3134
3135 if let Some(exceeded_error) = &self.exceeded_window_error {
3136 if model.model.id() == exceeded_error.model_id {
3137 return Some(TotalTokenUsage {
3138 total: exceeded_error.token_count,
3139 max,
3140 });
3141 }
3142 }
3143
3144 let total = self
3145 .token_usage_at_last_message()
3146 .unwrap_or_default()
3147 .total_tokens();
3148
3149 Some(TotalTokenUsage { total, max })
3150 }
3151
3152 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
3153 self.request_token_usage
3154 .get(self.messages.len().saturating_sub(1))
3155 .or_else(|| self.request_token_usage.last())
3156 .cloned()
3157 }
3158
3159 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3160 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3161 self.request_token_usage
3162 .resize(self.messages.len(), placeholder);
3163
3164 if let Some(last) = self.request_token_usage.last_mut() {
3165 *last = token_usage;
3166 }
3167 }
3168
3169 fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3170 self.project.update(cx, |project, cx| {
3171 project.user_store().update(cx, |user_store, cx| {
3172 user_store.update_model_request_usage(
3173 ModelRequestUsage(RequestUsage {
3174 amount: amount as i32,
3175 limit,
3176 }),
3177 cx,
3178 )
3179 })
3180 });
3181 }
3182
3183 pub fn deny_tool_use(
3184 &mut self,
3185 tool_use_id: LanguageModelToolUseId,
3186 tool_name: Arc<str>,
3187 window: Option<AnyWindowHandle>,
3188 cx: &mut Context<Self>,
3189 ) {
3190 let err = Err(anyhow::anyhow!(
3191 "Permission to run tool action denied by user"
3192 ));
3193
3194 self.tool_use.insert_tool_output(
3195 tool_use_id.clone(),
3196 tool_name,
3197 err,
3198 self.configured_model.as_ref(),
3199 self.completion_mode,
3200 );
3201 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3202 }
3203}
3204
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before more requests are served.
    #[error("Payment required")]
    PaymentRequired,
    /// The plan's model-request limit has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and detail message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3217
/// Events emitted by a `Thread` as completions stream in, tools run, and
/// thread metadata changes.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// A completion produced streamed output.
    StreamedCompletion,
    /// A text chunk arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text for the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model referenced a tool that doesn't exist or is not enabled.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, normally or with an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread summary changed.
    SummaryChanged,
    /// Pending tools are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint was inserted or updated.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// The tool-use limit was reached.
    ToolUseLimitReached,
    /// Listeners should cancel any in-progress editing.
    CancelEditing,
    /// An in-flight completion was canceled.
    CompletionCanceled,
    /// The agent profile for this thread changed.
    ProfileChanged,
    /// All retry attempts for a failed completion were exhausted.
    RetriesFailed {
        message: SharedString,
    },
}
3265
3266impl EventEmitter<ThreadEvent> for Thread {}
3267
/// A completion request that has been sent to the model and not yet resolved.
struct PendingCompletion {
    // Identifier used to match this completion when popping/canceling.
    id: usize,
    // NOTE(review): presumably tracks queued-vs-started state — confirm against QueueState's definition.
    queue_state: QueueState,
    // Keeps the streaming task alive; dropping it cancels the completion.
    _task: Task<()>,
}
3273
3274#[cfg(test)]
3275mod tests {
3276 use super::*;
3277 use crate::{
3278 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3279 };
3280
    // Test-specific constants
    // Delay (in seconds) the tests use as a provider-requested rate-limit retry interval.
    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3283 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3284 use assistant_tool::ToolRegistry;
3285 use assistant_tools;
3286 use futures::StreamExt;
3287 use futures::future::BoxFuture;
3288 use futures::stream::BoxStream;
3289 use gpui::TestAppContext;
3290 use http_client;
3291 use indoc::indoc;
3292 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3293 use language_model::{
3294 LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3295 LanguageModelProviderName, LanguageModelToolChoice,
3296 };
3297 use parking_lot::Mutex;
3298 use project::{FakeFs, Project};
3299 use prompt_store::PromptBuilder;
3300 use serde_json::json;
3301 use settings::{Settings, SettingsStore};
3302 use std::sync::Arc;
3303 use std::time::Duration;
3304 use theme::ThemeSettings;
3305 use util::path;
3306 use workspace::Workspace;
3307
    // Verifies that file context attached to a user message is included both
    // in the stored message and in the completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request contains the system message plus our user message;
        // the user message is the attached context followed by the prompt text.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3384
    /// Verifies that `ContextStore::new_context_for_thread` only returns context
    /// items that have not already been attached to an earlier message in the
    /// thread, and that passing an exclude-from message id re-includes the
    /// contexts of that message and later ones.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message: only file1 is in the store, so exactly one new context.
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message: the store now holds file1 and file2, but file1 is
        // already attached to message 1 and must be skipped.
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message: file1 and file2 are already attached; only file3 is new.
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Each message in the request carries exactly its own (new) context.
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Excluding from message 2 onward: contexts attached to message 2 and
        // later (file2, file3) count as "new" again, along with the fresh file4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3540
3541 #[gpui::test]
3542 async fn test_message_without_files(cx: &mut TestAppContext) {
3543 init_test_settings(cx);
3544
3545 let project = create_test_project(
3546 cx,
3547 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3548 )
3549 .await;
3550
3551 let (_, _thread_store, thread, _context_store, model) =
3552 setup_test_environment(cx, project.clone()).await;
3553
3554 // Insert user message without any context (empty context vector)
3555 let message_id = thread.update(cx, |thread, cx| {
3556 thread.insert_user_message(
3557 "What is the best way to learn Rust?",
3558 ContextLoadResult::default(),
3559 None,
3560 Vec::new(),
3561 cx,
3562 )
3563 });
3564
3565 // Check content and context in message object
3566 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3567
3568 // Context should be empty when no files are included
3569 assert_eq!(message.role, Role::User);
3570 assert_eq!(message.segments.len(), 1);
3571 assert_eq!(
3572 message.segments[0],
3573 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3574 );
3575 assert_eq!(message.loaded_context.text, "");
3576
3577 // Check message in request
3578 let request = thread.update(cx, |thread, cx| {
3579 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3580 });
3581
3582 assert_eq!(request.messages.len(), 2);
3583 assert_eq!(
3584 request.messages[1].string_contents(),
3585 "What is the best way to learn Rust?"
3586 );
3587
3588 // Add second message, also without context
3589 let message2_id = thread.update(cx, |thread, cx| {
3590 thread.insert_user_message(
3591 "Are there any good books?",
3592 ContextLoadResult::default(),
3593 None,
3594 Vec::new(),
3595 cx,
3596 )
3597 });
3598
3599 let message2 =
3600 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3601 assert_eq!(message2.loaded_context.text, "");
3602
3603 // Check that both messages appear in the request
3604 let request = thread.update(cx, |thread, cx| {
3605 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3606 });
3607
3608 assert_eq!(request.messages.len(), 3);
3609 assert_eq!(
3610 request.messages[1].string_contents(),
3611 "What is the best way to learn Rust?"
3612 );
3613 assert_eq!(
3614 request.messages[2].string_contents(),
3615 "Are there any good books?"
3616 );
3617 }
3618
    /// Verifies that editing a buffer that was previously attached as context
    /// produces exactly one `project_notifications` tool result warning about
    /// the stale file, and that flushing again does not duplicate it.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context. This will be a tracked buffer
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message and assistant response
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });

        // We shouldn't have a stale buffer notification yet
        let notifications = thread.read_with(cx, |thread, _| {
            find_tool_uses(thread, "project_notifications")
        });
        assert!(
            notifications.is_empty(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Modify the buffer so its content no longer matches what was attached.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Flushing notifications should now surface the stale-buffer warning.
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        // Exactly one notification is expected at this point.
        let [notification] = notifications.as_slice() else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]

            These files have changed since the last read:
            - code.rs
            "};
        assert_eq!(notification_content, expected_content);

        // Insert another user message and flush notifications again
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Can you tell me more?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // There should be no new notifications (we already flushed one)
        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        assert_eq!(
            notifications.len(),
            1,
            "Should still have only one notification after second flush - no duplicates"
        );
    }
3732
3733 fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3734 thread
3735 .messages()
3736 .flat_map(|message| {
3737 thread
3738 .tool_results_for_message(message.id)
3739 .into_iter()
3740 .filter(|result| result.tool_name == tool_name.into())
3741 .cloned()
3742 .collect::<Vec<_>>()
3743 })
3744 .collect()
3745 }
3746
3747 #[gpui::test]
3748 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3749 init_test_settings(cx);
3750
3751 let project = create_test_project(
3752 cx,
3753 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3754 )
3755 .await;
3756
3757 let (_workspace, thread_store, thread, _context_store, _model) =
3758 setup_test_environment(cx, project.clone()).await;
3759
3760 // Check that we are starting with the default profile
3761 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3762 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3763 assert_eq!(
3764 profile,
3765 AgentProfile::new(AgentProfileId::default(), tool_set)
3766 );
3767 }
3768
3769 #[gpui::test]
3770 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3771 init_test_settings(cx);
3772
3773 let project = create_test_project(
3774 cx,
3775 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3776 )
3777 .await;
3778
3779 let (_workspace, thread_store, thread, _context_store, _model) =
3780 setup_test_environment(cx, project.clone()).await;
3781
3782 // Profile gets serialized with default values
3783 let serialized = thread
3784 .update(cx, |thread, cx| thread.serialize(cx))
3785 .await
3786 .unwrap();
3787
3788 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3789
3790 let deserialized = cx.update(|cx| {
3791 thread.update(cx, |thread, cx| {
3792 Thread::deserialize(
3793 thread.id.clone(),
3794 serialized,
3795 thread.project.clone(),
3796 thread.tools.clone(),
3797 thread.prompt_builder.clone(),
3798 thread.project_context.clone(),
3799 None,
3800 cx,
3801 )
3802 })
3803 });
3804 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3805
3806 assert_eq!(
3807 deserialized.profile,
3808 AgentProfile::new(AgentProfileId::default(), tool_set)
3809 );
3810 }
3811
    /// Checks how `model_parameters` in the agent settings select a request
    /// temperature: a rule applies when its provider and/or model fields match
    /// the requesting model; a rule naming a different provider must not apply
    /// even if the model id matches.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Rule matching both provider and model -> temperature applies.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Rule matching only the model (provider unset) -> still applies.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Rule matching only the provider (model unset) -> still applies.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model id but a different provider ("anthropic") -> must NOT
        // apply; the request falls back to no temperature override.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3905
    /// Walks the thread summary state machine: Pending -> Generating (once a
    /// conversation exists) -> Ready. Manual `set_summary` is ignored while
    /// Pending or Generating, allowed once Ready, and an empty manual summary
    /// falls back to the default.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3990
3991 #[gpui::test]
3992 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3993 init_test_settings(cx);
3994
3995 let project = create_test_project(cx, json!({})).await;
3996
3997 let (_, _thread_store, thread, _context_store, model) =
3998 setup_test_environment(cx, project.clone()).await;
3999
4000 test_summarize_error(&model, &thread, cx);
4001
4002 // Now we should be able to set a summary
4003 thread.update(cx, |thread, cx| {
4004 thread.set_summary("Brief Intro", cx);
4005 });
4006
4007 thread.read_with(cx, |thread, _| {
4008 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4009 assert_eq!(thread.summary().or_default(), "Brief Intro");
4010 });
4011 }
4012
    /// After a summary-generation failure, sending further messages must not
    /// re-trigger summarization automatically, but calling `summarize`
    /// explicitly retries and can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the summary-error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the manual retry.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
4063
    /// Selects which failure `ErrorInjector` injects into every completion
    /// stream it produces.
    enum TestError {
        /// Simulate the provider reporting it is overloaded.
        Overloaded,
        /// Simulate the provider reporting an internal server error.
        InternalServerError,
    }
4069
    /// A `LanguageModel` implementation that delegates all metadata queries to
    /// a wrapped `FakeLanguageModel` but fails every `stream_completion` call
    /// with a configurable error, for exercising the thread's retry logic.
    struct ErrorInjector {
        // Delegate for id/name/provider/token queries.
        inner: Arc<FakeLanguageModel>,
        // Which error every stream_completion call will yield.
        error_type: TestError,
    }
4074
4075 impl ErrorInjector {
4076 fn new(error_type: TestError) -> Self {
4077 Self {
4078 inner: Arc::new(FakeLanguageModel::default()),
4079 error_type,
4080 }
4081 }
4082 }
4083
    // Everything except `stream_completion` is forwarded verbatim to the inner
    // fake model; `stream_completion` always yields the configured error.
    impl LanguageModel for ErrorInjector {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        // Ignores the request and returns a stream whose single item is the
        // error selected by `self.error_type`.
        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Build the error eagerly (outside the future) so it does not
            // borrow `self`.
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
                    provider: self.provider_name(),
                    retry_after: None,
                },
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError {
                        provider: self.provider_name(),
                        message: "I'm a teapot orbiting the sun".to_string(),
                    }
                }
            };
            async move {
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        // Tests use this to reach the wrapped fake model directly.
        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
4166
4167 #[gpui::test]
4168 async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4169 init_test_settings(cx);
4170
4171 let project = create_test_project(cx, json!({})).await;
4172 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4173
4174 // Create model that returns overloaded error
4175 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4176
4177 // Insert a user message
4178 thread.update(cx, |thread, cx| {
4179 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4180 });
4181
4182 // Start completion
4183 thread.update(cx, |thread, cx| {
4184 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4185 });
4186
4187 cx.run_until_parked();
4188
4189 thread.read_with(cx, |thread, _| {
4190 assert!(thread.retry_state.is_some(), "Should have retry state");
4191 let retry_state = thread.retry_state.as_ref().unwrap();
4192 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4193 assert_eq!(
4194 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4195 "Should have default max attempts"
4196 );
4197 });
4198
4199 // Check that a retry message was added
4200 thread.read_with(cx, |thread, _| {
4201 let mut messages = thread.messages();
4202 assert!(
4203 messages.any(|msg| {
4204 msg.role == Role::System
4205 && msg.ui_only
4206 && msg.segments.iter().any(|seg| {
4207 if let MessageSegment::Text(text) = seg {
4208 text.contains("overloaded")
4209 && text
4210 .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4211 } else {
4212 false
4213 }
4214 })
4215 }),
4216 "Should have added a system retry message"
4217 );
4218 });
4219
4220 let retry_count = thread.update(cx, |thread, _| {
4221 thread
4222 .messages
4223 .iter()
4224 .filter(|m| {
4225 m.ui_only
4226 && m.segments.iter().any(|s| {
4227 if let MessageSegment::Text(text) = s {
4228 text.contains("Retrying") && text.contains("seconds")
4229 } else {
4230 false
4231 }
4232 })
4233 })
4234 .count()
4235 });
4236
4237 assert_eq!(retry_count, 1, "Should have one retry message");
4238 }
4239
    /// An internal-server-error completion failure should arm the retry state
    /// and post a ui-only system message that names the provider ("Fake") and
    /// the attempt number, with exactly one retry message queued.
    #[gpui::test]
    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns internal server error
        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Check retry state on thread
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have correct max attempts"
            );
        });

        // Check that a retry message was added with provider name
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("internal")
                                    && text.contains("Fake")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message with provider name"
            );
        });

        // Count retry messages
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4315
4316 #[gpui::test]
4317 async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4318 init_test_settings(cx);
4319
4320 let project = create_test_project(cx, json!({})).await;
4321 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4322
4323 // Create model that returns overloaded error
4324 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4325
4326 // Insert a user message
4327 thread.update(cx, |thread, cx| {
4328 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4329 });
4330
4331 // Track retry events and completion count
4332 // Track completion events
4333 let completion_count = Arc::new(Mutex::new(0));
4334 let completion_count_clone = completion_count.clone();
4335
4336 let _subscription = thread.update(cx, |_, cx| {
4337 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4338 if let ThreadEvent::NewRequest = event {
4339 *completion_count_clone.lock() += 1;
4340 }
4341 })
4342 });
4343
4344 // First attempt
4345 thread.update(cx, |thread, cx| {
4346 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4347 });
4348 cx.run_until_parked();
4349
4350 // Should have scheduled first retry - count retry messages
4351 let retry_count = thread.update(cx, |thread, _| {
4352 thread
4353 .messages
4354 .iter()
4355 .filter(|m| {
4356 m.ui_only
4357 && m.segments.iter().any(|s| {
4358 if let MessageSegment::Text(text) = s {
4359 text.contains("Retrying") && text.contains("seconds")
4360 } else {
4361 false
4362 }
4363 })
4364 })
4365 .count()
4366 });
4367 assert_eq!(retry_count, 1, "Should have scheduled first retry");
4368
4369 // Check retry state
4370 thread.read_with(cx, |thread, _| {
4371 assert!(thread.retry_state.is_some(), "Should have retry state");
4372 let retry_state = thread.retry_state.as_ref().unwrap();
4373 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4374 });
4375
4376 // Advance clock for first retry
4377 cx.executor()
4378 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4379 cx.run_until_parked();
4380
4381 // Should have scheduled second retry - count retry messages
4382 let retry_count = thread.update(cx, |thread, _| {
4383 thread
4384 .messages
4385 .iter()
4386 .filter(|m| {
4387 m.ui_only
4388 && m.segments.iter().any(|s| {
4389 if let MessageSegment::Text(text) = s {
4390 text.contains("Retrying") && text.contains("seconds")
4391 } else {
4392 false
4393 }
4394 })
4395 })
4396 .count()
4397 });
4398 assert_eq!(retry_count, 2, "Should have scheduled second retry");
4399
4400 // Check retry state updated
4401 thread.read_with(cx, |thread, _| {
4402 assert!(thread.retry_state.is_some(), "Should have retry state");
4403 let retry_state = thread.retry_state.as_ref().unwrap();
4404 assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
4405 assert_eq!(
4406 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4407 "Should have correct max attempts"
4408 );
4409 });
4410
4411 // Advance clock for second retry (exponential backoff)
4412 cx.executor()
4413 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
4414 cx.run_until_parked();
4415
4416 // Should have scheduled third retry
4417 // Count all retry messages now
4418 let retry_count = thread.update(cx, |thread, _| {
4419 thread
4420 .messages
4421 .iter()
4422 .filter(|m| {
4423 m.ui_only
4424 && m.segments.iter().any(|s| {
4425 if let MessageSegment::Text(text) = s {
4426 text.contains("Retrying") && text.contains("seconds")
4427 } else {
4428 false
4429 }
4430 })
4431 })
4432 .count()
4433 });
4434 assert_eq!(
4435 retry_count, MAX_RETRY_ATTEMPTS as usize,
4436 "Should have scheduled third retry"
4437 );
4438
4439 // Check retry state updated
4440 thread.read_with(cx, |thread, _| {
4441 assert!(thread.retry_state.is_some(), "Should have retry state");
4442 let retry_state = thread.retry_state.as_ref().unwrap();
4443 assert_eq!(
4444 retry_state.attempt, MAX_RETRY_ATTEMPTS,
4445 "Should be at max retry attempt"
4446 );
4447 assert_eq!(
4448 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4449 "Should have correct max attempts"
4450 );
4451 });
4452
4453 // Advance clock for third retry (exponential backoff)
4454 cx.executor()
4455 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
4456 cx.run_until_parked();
4457
4458 // No more retries should be scheduled after clock was advanced.
4459 let retry_count = thread.update(cx, |thread, _| {
4460 thread
4461 .messages
4462 .iter()
4463 .filter(|m| {
4464 m.ui_only
4465 && m.segments.iter().any(|s| {
4466 if let MessageSegment::Text(text) = s {
4467 text.contains("Retrying") && text.contains("seconds")
4468 } else {
4469 false
4470 }
4471 })
4472 })
4473 .count()
4474 });
4475 assert_eq!(
4476 retry_count, MAX_RETRY_ATTEMPTS as usize,
4477 "Should not exceed max retries"
4478 );
4479
4480 // Final completion count should be initial + max retries
4481 assert_eq!(
4482 *completion_count.lock(),
4483 (MAX_RETRY_ATTEMPTS + 1) as usize,
4484 "Should have made initial + max retry attempts"
4485 );
4486 }
4487
    /// Exhausts the retry budget: the model fails every attempt with an
    /// "overloaded" error, so after `MAX_RETRY_ATTEMPTS` scheduled retries the
    /// thread must emit `ThreadEvent::RetriesFailed` and clear `retry_state`,
    /// leaving one ui-only `System` notice per attempt.
    #[gpui::test]
    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track events
        // Flipped by the subscription below the first time RetriesFailed fires.
        let retries_failed = Arc::new(Mutex::new(false));
        let retries_failed_clone = retries_failed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::RetriesFailed { .. } = event {
                    *retries_failed_clone.lock() = true;
                }
            })
        });

        // Start initial completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Advance through all retries
        // Retry delays follow exponential backoff (BASE, BASE*2, BASE*4, …);
        // over-advancing the fake clock is harmless, it just guarantees each
        // scheduled retry timer has elapsed before the next iteration.
        for i in 0..MAX_RETRY_ATTEMPTS {
            let delay = if i == 0 {
                BASE_RETRY_DELAY_SECS
            } else {
                BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
            };
            cx.executor().advance_clock(Duration::from_secs(delay));
            cx.run_until_parked();
        }

        // After the 3rd retry is scheduled, we need to wait for it to execute and fail
        // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
        let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
        cx.executor()
            .advance_clock(Duration::from_secs(final_delay));
        cx.run_until_parked();

        // Count the ui-only "Retrying … seconds" notices injected per attempt.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        // After max retries, should emit RetriesFailed event
        assert_eq!(
            retry_count, MAX_RETRY_ATTEMPTS as usize,
            "Should have attempted max retries"
        );
        assert!(
            *retries_failed.lock(),
            "Should emit RetriesFailed event after max retries exceeded"
        );

        // Retry state should be cleared
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after max retries"
            );

            // Verify we have the expected number of retry messages
            let retry_messages = thread
                .messages
                .iter()
                .filter(|msg| msg.ui_only && msg.role == Role::System)
                .count();
            assert_eq!(
                retry_messages, MAX_RETRY_ATTEMPTS as usize,
                "Should have one retry message per attempt"
            );
        });
    }
4585
    /// A transient "overloaded" failure inserts a ui-only `System` retry
    /// notice; this test checks that after the retried completion succeeds,
    /// the notice message still exists and keeps its `ui_only` / `System`
    /// markers (despite the test's historical name, it is not deleted).
    #[gpui::test]
    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // We'll use a wrapper to switch behavior after first failure
        struct RetryTestModel {
            // Delegate for everything except the first stream_completion call.
            inner: Arc<FakeLanguageModel>,
            // Set once the initial failure has been returned.
            failed_once: Arc<Mutex<bool>>,
        }

        // Delegating impl: every method forwards to `inner` except
        // `stream_completion`, which fails exactly once.
        impl LanguageModel for RetryTestModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }

            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track message deletions
        // Track when retry completes successfully
        let retry_completed = Arc::new(Mutex::new(false));
        let retry_completed_clone = retry_completed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::StreamedCompletion = event {
                    *retry_completed_clone.lock() = true;
                }
            })
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Get the retry message ID
        let retry_message_id = thread.read_with(cx, |thread, _| {
            thread
                .messages()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .map(|msg| msg.id)
                .expect("Should have a retry message")
        });

        // Wait for retry
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // Stream some successful content
        let fake_model = model.as_fake();
        // After the retry, there should be a new pending completion
        let pending = fake_model.pending_completions();
        assert!(
            !pending.is_empty(),
            "Should have a pending completion after retry"
        );
        fake_model.stream_completion_response(&pending[0], "Success!");
        fake_model.end_completion_stream(&pending[0]);
        cx.run_until_parked();

        // Check that the retry completed successfully
        assert!(
            *retry_completed.lock(),
            "Retry should have completed successfully"
        );

        // Retry message should still exist but be marked as ui_only
        thread.read_with(cx, |thread, _| {
            let retry_msg = thread
                .message(retry_message_id)
                .expect("Retry message should still exist");
            assert!(retry_msg.ui_only, "Retry message should be ui_only");
            assert_eq!(
                retry_msg.role,
                Role::System,
                "Retry message should have System role"
            );
        });
    }
4754
    /// Fail-once-then-succeed scenario: after the first failure `retry_state`
    /// must be populated, and after the retried completion succeeds it must be
    /// cleared again, with a real (non-ui-only) assistant message present.
    #[gpui::test]
    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create a model that fails once then succeeds
        struct FailOnceModel {
            // Delegate for everything except the first stream_completion call.
            inner: Arc<FakeLanguageModel>,
            // Set once the initial failure has been returned.
            failed_once: Arc<Mutex<bool>>,
        }

        // Delegating impl: only `stream_completion` has custom behavior.
        impl LanguageModel for FailOnceModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }
        }

        let fail_once_model = Arc::new(FailOnceModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Test message",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
        });

        // Start completion with fail-once model
        thread.update(cx, |thread, cx| {
            thread.send_to_model(
                fail_once_model.clone(),
                CompletionIntent::UserPrompt,
                None,
                cx,
            );
        });

        cx.run_until_parked();

        // Verify retry state exists after first failure
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_some(),
                "Should have retry state after failure"
            );
        });

        // Wait for retry delay
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // The retry should now use our FailOnceModel which should succeed
        // We need to help the FakeLanguageModel complete the stream
        let inner_fake = fail_once_model.inner.clone();

        // Wait a bit for the retry to start
        cx.run_until_parked();

        // Check for pending completions and complete them
        // NOTE(review): silently a no-op when no completion is pending; the
        // assertions below would then fail on the missing assistant message.
        if let Some(pending) = inner_fake.pending_completions().first() {
            inner_fake.stream_completion_response(pending, "Success!");
            inner_fake.end_completion_stream(pending);
        }
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after successful completion"
            );

            let has_assistant_message = thread
                .messages
                .iter()
                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
            assert!(
                has_assistant_message,
                "Should have an assistant message after successful retry"
            );
        });
    }
4916
4917 #[gpui::test]
4918 async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4919 init_test_settings(cx);
4920
4921 let project = create_test_project(cx, json!({})).await;
4922 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4923
4924 // Create a model that returns rate limit error with retry_after
4925 struct RateLimitModel {
4926 inner: Arc<FakeLanguageModel>,
4927 }
4928
4929 impl LanguageModel for RateLimitModel {
4930 fn id(&self) -> LanguageModelId {
4931 self.inner.id()
4932 }
4933
4934 fn name(&self) -> LanguageModelName {
4935 self.inner.name()
4936 }
4937
4938 fn provider_id(&self) -> LanguageModelProviderId {
4939 self.inner.provider_id()
4940 }
4941
4942 fn provider_name(&self) -> LanguageModelProviderName {
4943 self.inner.provider_name()
4944 }
4945
4946 fn supports_tools(&self) -> bool {
4947 self.inner.supports_tools()
4948 }
4949
4950 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4951 self.inner.supports_tool_choice(choice)
4952 }
4953
4954 fn supports_images(&self) -> bool {
4955 self.inner.supports_images()
4956 }
4957
4958 fn telemetry_id(&self) -> String {
4959 self.inner.telemetry_id()
4960 }
4961
4962 fn max_token_count(&self) -> u64 {
4963 self.inner.max_token_count()
4964 }
4965
4966 fn count_tokens(
4967 &self,
4968 request: LanguageModelRequest,
4969 cx: &App,
4970 ) -> BoxFuture<'static, Result<u64>> {
4971 self.inner.count_tokens(request, cx)
4972 }
4973
4974 fn stream_completion(
4975 &self,
4976 _request: LanguageModelRequest,
4977 _cx: &AsyncApp,
4978 ) -> BoxFuture<
4979 'static,
4980 Result<
4981 BoxStream<
4982 'static,
4983 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4984 >,
4985 LanguageModelCompletionError,
4986 >,
4987 > {
4988 let provider = self.provider_name();
4989 async move {
4990 let stream = futures::stream::once(async move {
4991 Err(LanguageModelCompletionError::RateLimitExceeded {
4992 provider,
4993 retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4994 })
4995 });
4996 Ok(stream.boxed())
4997 }
4998 .boxed()
4999 }
5000
5001 fn as_fake(&self) -> &FakeLanguageModel {
5002 &self.inner
5003 }
5004 }
5005
5006 let model = Arc::new(RateLimitModel {
5007 inner: Arc::new(FakeLanguageModel::default()),
5008 });
5009
5010 // Insert a user message
5011 thread.update(cx, |thread, cx| {
5012 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5013 });
5014
5015 // Start completion
5016 thread.update(cx, |thread, cx| {
5017 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5018 });
5019
5020 cx.run_until_parked();
5021
5022 let retry_count = thread.update(cx, |thread, _| {
5023 thread
5024 .messages
5025 .iter()
5026 .filter(|m| {
5027 m.ui_only
5028 && m.segments.iter().any(|s| {
5029 if let MessageSegment::Text(text) = s {
5030 text.contains("rate limit exceeded")
5031 } else {
5032 false
5033 }
5034 })
5035 })
5036 .count()
5037 });
5038 assert_eq!(retry_count, 1, "Should have scheduled one retry");
5039
5040 thread.read_with(cx, |thread, _| {
5041 assert!(
5042 thread.retry_state.is_none(),
5043 "Rate limit errors should not set retry_state"
5044 );
5045 });
5046
5047 // Verify we have one retry message
5048 thread.read_with(cx, |thread, _| {
5049 let retry_messages = thread
5050 .messages
5051 .iter()
5052 .filter(|msg| {
5053 msg.ui_only
5054 && msg.segments.iter().any(|seg| {
5055 if let MessageSegment::Text(text) = seg {
5056 text.contains("rate limit exceeded")
5057 } else {
5058 false
5059 }
5060 })
5061 })
5062 .count();
5063 assert_eq!(
5064 retry_messages, 1,
5065 "Should have one rate limit retry message"
5066 );
5067 });
5068
5069 // Check that retry message doesn't include attempt count
5070 thread.read_with(cx, |thread, _| {
5071 let retry_message = thread
5072 .messages
5073 .iter()
5074 .find(|msg| msg.role == Role::System && msg.ui_only)
5075 .expect("Should have a retry message");
5076
5077 // Check that the message doesn't contain attempt count
5078 if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5079 assert!(
5080 !text.contains("attempt"),
5081 "Rate limit retry message should not contain attempt count"
5082 );
5083 assert!(
5084 text.contains(&format!(
5085 "Retrying in {} seconds",
5086 TEST_RATE_LIMIT_RETRY_SECS
5087 )),
5088 "Rate limit retry message should contain retry delay"
5089 );
5090 }
5091 });
5092 }
5093
    /// Messages flagged `ui_only` (e.g. retry notices) must be excluded both
    /// from the completion request sent to the model and from thread
    /// serialization, while remaining visible in the in-memory thread.
    #[gpui::test]
    async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;

        // Insert a regular user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Insert a UI-only message (like our retry notifications)
        // Built by hand because there is no public API for ui-only messages.
        thread.update(cx, |thread, cx| {
            let id = thread.next_message_id.post_inc();
            thread.messages.push(Message {
                id,
                role: Role::System,
                segments: vec![MessageSegment::Text(
                    "This is a UI-only message that should not be sent to the model".to_string(),
                )],
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
                is_hidden: true,
                ui_only: true,
            });
            cx.emit(ThreadEvent::MessageAdded(id));
        });

        // Insert another regular message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
        });

        // Generate the completion request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Verify that the request only contains non-UI-only messages
        // Should have system prompt + 2 user messages, but not the UI-only message
        let user_messages: Vec<_> = request
            .messages
            .iter()
            .filter(|msg| msg.role == Role::User)
            .collect();
        assert_eq!(
            user_messages.len(),
            2,
            "Should have exactly 2 user messages"
        );

        // Verify the UI-only content is not present anywhere in the request
        // Flatten every text segment of every message into one string.
        let request_text = request
            .messages
            .iter()
            .flat_map(|msg| &msg.content)
            .filter_map(|content| match content {
                MessageContent::Text(text) => Some(text.as_str()),
                _ => None,
            })
            .collect::<String>();

        assert!(
            !request_text.contains("UI-only message"),
            "UI-only message content should not be in the request"
        );

        // Verify the thread still has all 3 messages (including UI-only)
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.messages().count(),
                3,
                "Thread should have 3 messages"
            );
            assert_eq!(
                thread.messages().filter(|m| m.ui_only).count(),
                1,
                "Thread should have 1 UI-only message"
            );
        });

        // Verify that UI-only messages are not serialized
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();
        assert_eq!(
            serialized.messages.len(),
            2,
            "Serialized thread should only have 2 messages (no UI-only)"
        );
    }
5193
    /// Cancelling the in-flight completion while a retry is pending must
    /// abort the scheduled retry: no new completion is started and
    /// `retry_state` is cleared.
    #[gpui::test]
    async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Verify retry was scheduled by checking for retry message
        let has_retry_message = thread.read_with(cx, |thread, _| {
            thread.messages.iter().any(|m| {
                m.ui_only
                    && m.segments.iter().any(|s| {
                        if let MessageSegment::Text(text) = s {
                            text.contains("Retrying") && text.contains("seconds")
                        } else {
                            false
                        }
                    })
            })
        });
        assert!(has_retry_message, "Should have scheduled a retry");

        // Cancel the completion before the retry happens
        // Deliberately without advancing the clock, so the retry timer has
        // not fired yet when cancellation lands.
        thread.update(cx, |thread, cx| {
            thread.cancel_last_completion(None, cx);
        });

        cx.run_until_parked();

        // The retry should not have happened - no pending completions
        let fake_model = model.as_fake();
        assert_eq!(
            fake_model.pending_completions().len(),
            0,
            "Should have no pending completions after cancellation"
        );

        // Verify the retry was cancelled by checking retry state
        thread.read_with(cx, |thread, _| {
            if let Some(retry_state) = &thread.retry_state {
                panic!(
                    "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
                    retry_state.attempt, retry_state.max_attempts, retry_state.intent
                );
            }
        });
    }
5256
    /// Shared assertion helper: sends a summarization request whose response
    /// stream ends without producing a summary, then checks that the thread's
    /// summary transitions Generating -> Error while `or_default()` keeps
    /// returning the default title throughout.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        // Complete the main completion so the summarization request starts.
        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Summary generation is still in flight at this point.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
5291
    /// Drives `fake_model`'s most recent completion to a successful finish:
    /// streams one response chunk and ends the stream, parking the executor
    /// before and after so the thread observes every event.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
5298
    /// Registers every global the thread tests depend on: the settings store,
    /// language/project/agent/workspace/theme/language-model settings, the
    /// prompt and thread stores, and the tool registries. Must run before any
    /// other test helper. Initialization order follows each crate's
    /// registration dependencies.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
            assistant_tool::init(cx);

            // Tools get a stubbed HTTP client that answers 200 to everything,
            // so no test ever touches the network.
            let http_client = Arc::new(http_client::HttpClientWithUrl::new(
                http_client::FakeHttpClient::with_200_response(),
                "http://localhost".to_string(),
                None,
            ));
            assistant_tools::init(http_client, cx);
        });
    }
5322
5323 // Helper to create a test project with test files
5324 async fn create_test_project(
5325 cx: &mut TestAppContext,
5326 files: serde_json::Value,
5327 ) -> Entity<Project> {
5328 let fs = FakeFs::new(cx.executor());
5329 fs.insert_tree(path!("/test"), files).await;
5330 Project::test(fs, [path!("/test").as_ref()], cx).await
5331 }
5332
    /// Boots the full test fixture: a test window/workspace, a `ThreadStore`
    /// backed by `project`, a fresh `Thread`, a `ContextStore`, and a
    /// `FakeLanguageModel` registered as both the default model and the
    /// thread-summary model in the global `LanguageModelRegistry`.
    ///
    /// Returns `(workspace, thread_store, thread, context_store, model)`.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model for both roles so completions and thread
        // summarization resolve to the same controllable model.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
5387
5388 async fn add_file_to_context(
5389 project: &Entity<Project>,
5390 context_store: &Entity<ContextStore>,
5391 path: &str,
5392 cx: &mut TestAppContext,
5393 ) -> Result<Entity<language::Buffer>> {
5394 let buffer_path = project
5395 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5396 .unwrap();
5397
5398 let buffer = project
5399 .update(cx, |project, cx| {
5400 project.open_buffer(buffer_path.clone(), cx)
5401 })
5402 .await
5403 .unwrap();
5404
5405 context_store.update(cx, |context_store, cx| {
5406 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5407 });
5408
5409 Ok(buffer)
5410 }
5411}