1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9 tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
10};
11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
12use anyhow::{Result, anyhow};
13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
14use chrono::{DateTime, Utc};
15use client::{ModelRequestUsage, RequestUsage};
16use collections::HashMap;
17use feature_flags::{self, FeatureFlagAppExt};
18use futures::{FutureExt, StreamExt as _, future::Shared};
19use git::repository::DiffType;
20use gpui::{
21 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
22 WeakEntity, Window,
23};
24use language_model::{
25 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
27 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
28 LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
29 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
30 TokenUsage,
31};
32use postage::stream::Stream as _;
33use project::{
34 Project,
35 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
36};
37use prompt_store::{ModelContext, PromptBuilder};
38use proto::Plan;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use settings::Settings;
42use std::{
43 io::Write,
44 ops::Range,
45 sync::Arc,
46 time::{Duration, Instant},
47};
48use thiserror::Error;
49use util::{ResultExt as _, debug_panic, post_inc};
50use uuid::Uuid;
51use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
52
/// Maximum number of automatic retry attempts for a failed completion request.
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay, in seconds, between retry attempts (retry scheduling itself is
/// implemented elsewhere in this file).
const BASE_RETRY_DELAY_SECS: u64 = 5;

/// Unique identifier of a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Creates a new random (UUID v4) thread ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    /// Builds a `ThreadId` from an existing string, e.g. a persisted ID.
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
78
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a new random (UUID v4) prompt ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Identifier of a message within a thread, allocated from a monotonically
/// increasing counter (see `Thread::next_message_id`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    /// Exposes the raw index backing this ID.
    pub fn as_usize(&self) -> usize {
        self.0
    }
}
109
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text covered by the crease.
    pub range: Range<usize>,
    /// Icon shown for the collapsed crease.
    pub icon_path: SharedString,
    /// Label shown for the collapsed crease.
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (plain text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context attached to the message when it was sent.
    pub loaded_context: LoadedContext,
    /// Creases used to re-render attached-context chips in an editor.
    pub creases: Vec<MessageCrease>,
    /// Hidden messages exist in the transcript but are flagged as not
    /// user-visible (e.g. the auto-inserted "continue" prompt).
    pub is_hidden: bool,
    /// UI-only messages are skipped during serialization (see `serialize`).
    pub ui_only: bool,
}
131
132impl Message {
133 /// Returns whether the message contains any meaningful text that should be displayed
134 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
135 pub fn should_display_content(&self) -> bool {
136 self.segments.iter().all(|segment| segment.should_display())
137 }
138
139 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
140 if let Some(MessageSegment::Thinking {
141 text: segment,
142 signature: current_signature,
143 }) = self.segments.last_mut()
144 {
145 if let Some(signature) = signature {
146 *current_signature = Some(signature);
147 }
148 segment.push_str(text);
149 } else {
150 self.segments.push(MessageSegment::Thinking {
151 text: text.to_string(),
152 signature,
153 });
154 }
155 }
156
157 pub fn push_redacted_thinking(&mut self, data: String) {
158 self.segments.push(MessageSegment::RedactedThinking(data));
159 }
160
161 pub fn push_text(&mut self, text: &str) {
162 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
163 segment.push_str(text);
164 } else {
165 self.segments.push(MessageSegment::Text(text.to_string()));
166 }
167 }
168
169 pub fn to_string(&self) -> String {
170 let mut result = String::new();
171
172 if !self.loaded_context.text.is_empty() {
173 result.push_str(&self.loaded_context.text);
174 }
175
176 for segment in &self.segments {
177 match segment {
178 MessageSegment::Text(text) => result.push_str(text),
179 MessageSegment::Thinking { text, .. } => {
180 result.push_str("<think>\n");
181 result.push_str(text);
182 result.push_str("\n</think>");
183 }
184 MessageSegment::RedactedThinking(_) => {}
185 }
186 }
187
188 result
189 }
190}
191
/// A single piece of content within a [`Message`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary text content.
    Text(String),
    /// Model "thinking" output, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque, provider-encrypted thinking data; never rendered.
    RedactedThinking(String),
}

impl MessageSegment {
    /// Whether this segment carries visible content worth rendering.
    ///
    /// A segment is displayable when its text is non-empty; redacted thinking
    /// is never displayed. (The previous `text.is_empty()` check was inverted:
    /// it reported empty segments as displayable and non-empty ones as not.)
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }

    /// The plain text of a `Text` segment; `None` for other variants.
    pub fn text(&self) -> Option<&str> {
        match self {
            MessageSegment::Text(text) => Some(text),
            _ => None,
        }
    }
}
218
/// A point-in-time capture of the project's worktrees.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including git state when available.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when no git state was captured for this worktree.
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// Associates a git checkpoint with the message that created it, so the
/// project can later be restored to the state it had at that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or message (thumbs up / thumbs down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// Outcome of the most recent checkpoint restore: either still in flight,
/// or failed with an error message.
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
261
262impl LastRestoreCheckpoint {
263 pub fn message_id(&self) -> MessageId {
264 match self {
265 LastRestoreCheckpoint::Pending { message_id } => *message_id,
266 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
267 }
268 }
269}
270
/// Progress of the on-demand "detailed summary" of a thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// Generation is in progress (`message_id` presumably marks the last
    /// message included — generation logic is not in this chunk).
    Generating {
        message_id: MessageId,
    },
    /// Generation finished with `text`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
283
284impl DetailedSummaryState {
285 fn text(&self) -> Option<SharedString> {
286 if let Self::Generated { text, .. } = self {
287 Some(text.clone())
288 } else {
289 None
290 }
291 }
292}
293
/// Aggregated token consumption for a thread, relative to the model's
/// context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Size of the context window; `0` when no model is selected.
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, nearing the limit, or exceeded.
    pub fn ratio(&self) -> TokenUsageRatio {
        // Fraction of the context window at which we start warning.
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        // In debug builds the threshold can be overridden via an environment
        // variable for testing. Fall back to the default instead of panicking
        // when the variable is unset or not a valid float (previously an
        // invalid value would `unwrap()`-panic).
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more consumed, saturating
    /// at `u64::MAX` instead of overflowing.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total.saturating_add(tokens),
            max: self.max,
        }
    }
}

/// How close the thread is to exhausting the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
338
/// Status of a pending completion request.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent; no queue information yet.
    Sending,
    /// The request is queued at the given position.
    Queued { position: usize },
    /// The request has started.
    Started,
}
345
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last modification time; bumped by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Short title for the thread.
    summary: ThreadSummary,
    /// In-flight task generating `summary`.
    pending_summary: Task<Option<()>>,
    /// In-flight task generating the detailed summary.
    detailed_summary_task: Task<Option<()>>,
    /// Watch channel broadcasting detailed-summary progress.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// All messages, kept sorted by ascending `MessageId` (relied on by
    /// `Thread::message`'s binary search).
    messages: Vec<Message>,
    /// Counter from which new message IDs are allocated.
    next_message_id: MessageId,
    /// ID of the most recent user prompt submission.
    last_prompt_id: PromptId,
    /// Shared system-prompt/project context.
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per message.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// Completion requests currently in flight.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    /// Full set of tools the agent can draw from.
    tools: Entity<ToolWorkingSet>,
    /// Tracks tool invocations and their results across messages.
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// State of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting finalization once generation completes
    /// (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project taken when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Token usage accumulated over the thread's lifetime.
    cumulative_token_usage: TokenUsage,
    /// Set when the last message exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    /// Thread-level user feedback.
    feedback: Option<ThreadFeedback>,
    /// Automatic-retry bookkeeping for the current request.
    retry_state: Option<RetryState>,
    /// Per-message user feedback.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// Last time telemetry was auto-captured.
    last_auto_capture_at: Option<Instant>,
    /// Last time a streaming chunk arrived (used for staleness detection).
    last_received_chunk_at: Option<Instant>,
    /// Optional observer of each request and its response events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Budget of remaining turns; initialized to `u32::MAX`.
    remaining_turns: u32,
    /// Currently selected model, if any.
    configured_model: Option<ConfiguredModel>,
    /// Active agent profile, controlling which tools are enabled.
    profile: AgentProfile,
}

/// Bookkeeping for automatically retrying a failed completion request.
#[derive(Clone, Debug)]
struct RetryState {
    /// Attempt counter for the current retry sequence.
    attempt: u8,
    max_attempts: u8,
    /// Intent to resend with on retry.
    intent: CompletionIntent,
}
394
/// Lifecycle of a thread's generated summary/title.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// No summary has been generated yet.
    Pending,
    /// A summary is currently being generated.
    Generating,
    /// A summary is available.
    Ready(SharedString),
    /// Summary generation failed.
    Error,
}
402
403impl ThreadSummary {
404 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
405
406 pub fn or_default(&self) -> SharedString {
407 self.unwrap_or(Self::DEFAULT)
408 }
409
410 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
411 self.ready().unwrap_or_else(|| message.into())
412 }
413
414 pub fn ready(&self) -> Option<SharedString> {
415 match self {
416 ThreadSummary::Ready(summary) => Some(summary.clone()),
417 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
418 }
419 }
420}
421
/// Records that a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
429
430impl Thread {
    /// Creates an empty thread, pulling the default model, profile, and
    /// completion mode from the global settings/registry.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the initial project snapshot in the background;
            // consumers await the shared task when they need it.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
486
    /// Reconstructs a thread from its serialized form.
    ///
    /// Missing persisted settings (model, completion mode, profile) fall back
    /// to the current global defaults.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the
        // registry's default model when it can't be selected.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; structured
                    // context handles and images cannot be restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
610
    /// Installs a callback that observes each completion request along with
    /// the events (or stringified errors) received in response.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// The agent profile currently in effect for this thread.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches to a different agent profile, emitting `ProfileChanged` only
    /// when the ID actually differs from the current one.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the modification timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to the registry's
    /// default model the first time none is set.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
669
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Sets the summary, normalizing empty input to the default title.
    ///
    /// Ignored while a summary is pending or being generated; emits
    /// `SummaryChanged` only when the text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // After an error, compare against the default title.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by ID.
    ///
    /// Uses binary search, relying on `messages` being sorted by ID (IDs are
    /// allocated from a monotonically increasing counter).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
713
    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        // NOTE(review): strictly, this is `None` whenever no chunk has been
        // received yet, not only when generation is inactive.
        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived (for staleness tracking).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
758
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    /// Restores the project to `checkpoint`'s git state.
    ///
    /// On success the thread is truncated back to the checkpoint's message;
    /// on failure the error is recorded in `last_restore_checkpoint`. Emits
    /// `CheckpointChanged` both when starting and when finished.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // A restore invalidates whatever checkpoint was pending.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
798
    /// Finalizes the pending checkpoint once generation has fully finished;
    /// no-op while still generating or when there is no pending checkpoint.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        self.finalize_checkpoint(pending_checkpoint, cx);
    }

    /// Compares the pending checkpoint against the current git state and keeps
    /// it only if something actually changed (or if the comparison failed).
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only record the checkpoint when the repository actually
                // changed between the two snapshots.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint: conservatively keep the
            // pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
845
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes the message with `message_id` and every later message, along
    /// with their checkpoints. No-op if the message is not found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
870
    /// Iterates over the context items attached to the given message; empty
    /// when the message doesn't exist or carries no context.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
878
879 pub fn is_turn_end(&self, ix: usize) -> bool {
880 if self.messages.is_empty() {
881 return false;
882 }
883
884 if !self.is_generating() && ix == self.messages.len() - 1 {
885 return true;
886 }
887
888 let Some(message) = self.messages.get(ix) else {
889 return false;
890 };
891
892 if message.role != Role::Assistant {
893 return false;
894 }
895
896 self.messages
897 .get(ix + 1)
898 .and_then(|message| {
899 self.message(message.id)
900 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
901 })
902 .unwrap_or(false)
903 }
904
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            // Errored tool uses won't run, so they can't edit anything.
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }
927
    /// Tool uses issued by the given (assistant) message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced for the given assistant message's tool uses.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a finished tool use, if any; image results yield `None`.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The rendered UI card for a tool result, when one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
956
957 /// Return tools that are both enabled and supported by the model
958 pub fn available_tools(
959 &self,
960 cx: &App,
961 model: Arc<dyn LanguageModel>,
962 ) -> Vec<LanguageModelRequestTool> {
963 if model.supports_tools() {
964 self.profile
965 .enabled_tools(cx)
966 .into_iter()
967 .filter_map(|(name, tool)| {
968 // Skip tools that cannot be supported
969 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
970 Some(LanguageModelRequestTool {
971 name: name.into(),
972 description: tool.description(),
973 input_schema,
974 })
975 })
976 .collect()
977 } else {
978 Vec::default()
979 }
980 }
981
    /// Appends a user message with its attached context and optional git
    /// checkpoint, and marks any context-referenced buffers as read in the
    /// action log.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        // Hold the checkpoint until generation completes; it is kept or
        // discarded in `finalize_pending_checkpoint`.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
1018
    /// Inserts a hidden "Continue where you left off" user message and clears
    /// any pending checkpoint.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    /// Appends an assistant message consisting of the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1047
    /// Core message insertion: assigns the next ID, appends the message,
    /// bumps `updated_at`, and emits `MessageAdded`.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
            ui_only: false,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1071
    /// Replaces the role/segments/creases of an existing message, optionally
    /// updating its context and recording a new checkpoint for it.
    ///
    /// Returns `false` (without emitting) when no message has the given ID.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1104
1105 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1106 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1107 return false;
1108 };
1109 self.messages.remove(index);
1110 self.touch_updated_at();
1111 cx.emit(ThreadEvent::MessageDeleted(id));
1112 true
1113 }
1114
1115 /// Returns the representation of this [`Thread`] in a textual form.
1116 ///
1117 /// This is the representation we use when attaching a thread as context to another thread.
1118 pub fn text(&self) -> String {
1119 let mut text = String::new();
1120
1121 for message in &self.messages {
1122 text.push_str(match message.role {
1123 language_model::Role::User => "User:",
1124 language_model::Role::Assistant => "Agent:",
1125 language_model::Role::System => "System:",
1126 });
1127 text.push('\n');
1128
1129 for segment in &message.segments {
1130 match segment {
1131 MessageSegment::Text(content) => text.push_str(content),
1132 MessageSegment::Thinking { text: content, .. } => {
1133 text.push_str(&format!("<think>{}</think>", content))
1134 }
1135 MessageSegment::RedactedThinking(_) => {}
1136 }
1137 }
1138 text.push('\n');
1139 }
1140
1141 text
1142 }
1143
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Returns a task because serialization waits for the (shared) initial
    /// project snapshot future to resolve before reading the thread state.
    /// UI-only messages are excluded from the serialized form.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Await the snapshot first so the read below stays synchronous.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                // ui_only messages (e.g. retry notices) are presentation
                // state, not conversation content, so they are not persisted.
                messages: this
                    .messages()
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        // Map each in-memory segment 1:1 to its serialized twin.
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted, not
                        // the full LoadedContext structure.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1230
    /// Returns how many model turns this thread may still take before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1234
    /// Sets the turn budget; each call to `send_to_model` consumes one turn.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1238
    /// Sends the current thread state to `model`, consuming one remaining
    /// turn.
    ///
    /// Does nothing once the turn budget is exhausted; callers control the
    /// budget via `set_remaining_turns`.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        // Attach any pending auto-generated notifications *before* building
        // the request so they are included in this turn.
        self.flush_notifications(model.clone(), intent, cx);

        let request = self.to_completion_request(model.clone(), intent, cx);

        self.stream_completion(request, model, intent, window, cx);
    }
1258
1259 pub fn used_tools_since_last_user_message(&self) -> bool {
1260 for message in self.messages.iter().rev() {
1261 if self.tool_use.message_has_tool_results(message.id) {
1262 return true;
1263 } else if message.role == Role::User {
1264 return false;
1265 }
1266 }
1267
1268 false
1269 }
1270
    /// Builds the [`LanguageModelRequest`] representing this thread's current
    /// state for `intent`.
    ///
    /// Prepends a generated system prompt (emitting a `ShowError` event if
    /// generation fails), converts every non-UI-only message, pairs tool uses
    /// with their results, and marks the last fully-resolved message for
    /// prompt caching.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        // The system prompt template needs to know which tools exist.
        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible for prompt caching; only
        // one message gets `cache = true` (see the Anthropic link below).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results are delivered back to the model as a follow-up
            // User message; a message with still-pending tool uses must not
            // be cached because its content will change.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn mode only applies to models that support it; everyone else is
        // pinned to normal mode regardless of the thread setting.
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1434
1435 fn to_summarize_request(
1436 &self,
1437 model: &Arc<dyn LanguageModel>,
1438 intent: CompletionIntent,
1439 added_user_message: String,
1440 cx: &App,
1441 ) -> LanguageModelRequest {
1442 let mut request = LanguageModelRequest {
1443 thread_id: None,
1444 prompt_id: None,
1445 intent: Some(intent),
1446 mode: None,
1447 messages: vec![],
1448 tools: Vec::new(),
1449 tool_choice: None,
1450 stop: Vec::new(),
1451 temperature: AgentSettings::temperature_for_model(model, cx),
1452 };
1453
1454 for message in &self.messages {
1455 let mut request_message = LanguageModelRequestMessage {
1456 role: message.role,
1457 content: Vec::new(),
1458 cache: false,
1459 };
1460
1461 for segment in &message.segments {
1462 match segment {
1463 MessageSegment::Text(text) => request_message
1464 .content
1465 .push(MessageContent::Text(text.clone())),
1466 MessageSegment::Thinking { .. } => {}
1467 MessageSegment::RedactedThinking(_) => {}
1468 }
1469 }
1470
1471 if request_message.content.is_empty() {
1472 continue;
1473 }
1474
1475 request.messages.push(request_message);
1476 }
1477
1478 request.messages.push(LanguageModelRequestMessage {
1479 role: Role::User,
1480 content: vec![MessageContent::Text(added_user_message)],
1481 cache: false,
1482 });
1483
1484 request
1485 }
1486
    /// Insert auto-generated notifications (if any) to the thread
    ///
    /// Only user-driven turns (`UserPrompt` / `ToolResults`) receive
    /// notifications; summarization, file-edit, and inline-assist intents are
    /// deliberately excluded.
    fn flush_notifications(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) {
        match intent {
            CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
                if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
                    // The simulated tool call already ran, so announce it as
                    // finished for the UI.
                    cx.emit(ThreadEvent::ToolFinished {
                        tool_use_id: pending_tool_use.id.clone(),
                        pending_tool_use: Some(pending_tool_use),
                    });
                }
            }
            // Listed exhaustively (no `_` arm) so adding a new intent forces
            // a decision here.
            CompletionIntent::ThreadSummarization
            | CompletionIntent::ThreadContextSummarization
            | CompletionIntent::CreateFile
            | CompletionIntent::EditFile
            | CompletionIntent::InlineAssist
            | CompletionIntent::TerminalInlineAssist
            | CompletionIntent::GenerateGitCommitMessage => {}
        };
    }
1512
    /// Injects a stale-buffer notification into the thread as a simulated
    /// `project_notifications` tool call.
    ///
    /// Returns `None` when there is nothing to notify, the tool is missing or
    /// disabled by the current profile, or no Assistant message exists to
    /// attach the call to. Note that the tool's output future is resolved
    /// synchronously by blocking on the background executor.
    fn attach_tracked_files_state(
        &mut self,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Option<PendingToolUse> {
        let action_log = self.action_log.read(cx);

        // Bail out early when there are no stale buffers to report.
        action_log.unnotified_stale_buffers(cx).next()?;

        // Represent notification as a simulated `project_notifications` tool call
        let tool_name = Arc::from("project_notifications");
        let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
            debug_panic!("`project_notifications` tool not found");
            return None;
        };

        // Respect the profile's tool gating just like a real tool call.
        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return None;
        }

        let input = serde_json::json!({});
        let request = Arc::new(LanguageModelRequest::default()); // unused
        let window = None;
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model.clone(),
            window,
            cx,
        );

        // NOTE(review): the id is derived from the current message count and
        // presumably unique enough for this simulated call — confirm it cannot
        // collide after message deletions.
        let tool_use_id =
            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));

        let tool_use = LanguageModelToolUse {
            id: tool_use_id.clone(),
            name: tool_name.clone(),
            raw_input: "{}".to_string(),
            input: serde_json::json!({}),
            is_input_complete: true,
        };

        // Block until the tool finishes so the result can be attached
        // synchronously before the request is built.
        let tool_output = cx.background_executor().block(tool_result.output);

        // Attach a project_notification tool call to the latest existing
        // Assistant message. We cannot create a new Assistant message
        // because thinking models require a `thinking` block that we
        // cannot mock. We cannot send a notification as a normal
        // (non-tool-use) User message because this distracts Agent
        // too much.
        let tool_message_id = self
            .messages
            .iter()
            .enumerate()
            .rfind(|(_, message)| message.role == Role::Assistant)
            .map(|(_, message)| message.id)?;

        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: self.last_prompt_id.clone(),
        };

        // Register the use, then immediately record its output, mirroring the
        // lifecycle of a genuine tool invocation.
        self.tool_use
            .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            tool_output,
            self.configured_model.as_ref(),
            self.completion_mode,
        );

        pending_tool_use
    }
1591
    /// Streams a completion for `request` from `model`, applying events to
    /// the thread as they arrive.
    ///
    /// Spawns a task (registered as a `PendingCompletion`) that drives the
    /// event stream: text/thinking/tool-use events mutate the message list,
    /// stop reasons trigger tool execution or refusal cleanup, and retryable
    /// errors may schedule a follow-up `send_to_model` instead of surfacing
    /// an error.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only capture the request/response pair when a callback is
        // registered (avoids cloning the request otherwise).
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let completion_mode = request
            .mode
            .unwrap_or(zed_llm_client::CompletionMode::Normal);

        // Treat the request start as the first "chunk" timestamp.
        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Id of the Assistant message this response streams into;
                // created lazily on the first event that needs one.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        match event? {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates are cumulative per request, so
                                // add only the delta since the previous update.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking { data } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Partial inputs are forwarded to the UI so it
                                // can render the streaming tool call.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            } => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued { position } => {
                                            completion.queue_state =
                                                QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code,
                                            message,
                                            request_id: _,
                                            retry_after,
                                        } => {
                                            // Convert cloud-reported failures
                                            // into a typed completion error so
                                            // the retry logic below sees them.
                                            return Err(
                                                LanguageModelCompletionError::from_cloud_failure(
                                                    model.upstream_provider_name(),
                                                    code,
                                                    message,
                                                    retry_after.map(Duration::from_secs_f64),
                                                ),
                                            );
                                        }
                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
                                            thread.update_model_request_usage(
                                                amount as u32,
                                                limit,
                                                cx,
                                            );
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Give other tasks (e.g. UI reacting to the events above)
                    // a chance to run between stream events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;
            let mut retry_scheduled = false;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => {
                            match stop_reason {
                                StopReason::ToolUse => {
                                    let tool_uses =
                                        thread.use_pending_tools(window, model.clone(), cx);
                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                                }
                                StopReason::EndTurn | StopReason::MaxTokens => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });
                                }
                                StopReason::Refusal => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });

                                    // Remove the turn that was refused.
                                    //
                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                    {
                                        let mut messages_to_remove = Vec::new();

                                        // Walk backwards collecting messages
                                        // until the user message that started
                                        // this turn (inclusive).
                                        for (ix, message) in
                                            thread.messages.iter().enumerate().rev()
                                        {
                                            messages_to_remove.push(message.id);

                                            if message.role == Role::User {
                                                if ix == 0 {
                                                    break;
                                                }

                                                if let Some(prev_message) =
                                                    thread.messages.get(ix - 1)
                                                {
                                                    if prev_message.role == Role::Assistant {
                                                        break;
                                                    }
                                                }
                                            }
                                        }

                                        for message_id in messages_to_remove {
                                            thread.delete_message(message_id, cx);
                                        }
                                    }

                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                        header: "Language model refusal".into(),
                                        message:
                                            "Model refused to generate content for safety reasons."
                                                .into(),
                                    }));
                                }
                            }

                            // We successfully completed, so cancel any remaining retries.
                            thread.retry_state = None;
                        }
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Fallback used when the error has no dedicated
                            // handling or no retry could be scheduled.
                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(completion_error) =
                                error.downcast_ref::<LanguageModelCompletionError>()
                            {
                                use LanguageModelCompletionError::*;
                                match &completion_error {
                                    PromptTooLarge { tokens, .. } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread
                                                .total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(
                                                    model
                                                        .max_token_count_for_mode(completion_mode)
                                                        .saturating_add(1),
                                                )
                                        });
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: tokens,
                                        });
                                        cx.notify();
                                    }
                                    // Provider told us exactly when to retry:
                                    // single retry after that interval.
                                    RateLimitExceeded {
                                        retry_after: Some(retry_after),
                                        ..
                                    }
                                    | ServerOverloaded {
                                        retry_after: Some(retry_after),
                                        ..
                                    } => {
                                        thread.handle_rate_limit_error(
                                            &completion_error,
                                            *retry_after,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        retry_scheduled = true;
                                    }
                                    // No retry-after hint: exponential backoff.
                                    RateLimitExceeded { .. } | ServerOverloaded { .. } => {
                                        retry_scheduled = thread.handle_retryable_error(
                                            &completion_error,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        if !retry_scheduled {
                                            emit_generic_error(error, cx);
                                        }
                                    }
                                    ApiInternalServerError { .. }
                                    | ApiReadResponseError { .. }
                                    | HttpSend { .. } => {
                                        retry_scheduled = thread.handle_retryable_error(
                                            &completion_error,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        if !retry_scheduled {
                                            emit_generic_error(error, cx);
                                        }
                                    }
                                    // Non-transient errors: never retried.
                                    NoApiKey { .. }
                                    | HttpResponseError { .. }
                                    | BadRequestFormat { .. }
                                    | AuthenticationError { .. }
                                    | PermissionError { .. }
                                    | ApiEndpointNotFound { .. }
                                    | SerializeRequest { .. }
                                    | BuildRequestBody { .. }
                                    | DeserializeResponse { .. }
                                    | Other { .. } => emit_generic_error(error, cx),
                                }
                            } else {
                                emit_generic_error(error, cx);
                            }

                            if !retry_scheduled {
                                thread.cancel_last_completion(window, cx);
                            }
                        }
                    }

                    if !retry_scheduled {
                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
                    }

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
2084
2085 pub fn summarize(&mut self, cx: &mut Context<Self>) {
2086 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2087 println!("No thread summary model");
2088 return;
2089 };
2090
2091 if !model.provider.is_authenticated(cx) {
2092 return;
2093 }
2094
2095 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2096
2097 let request = self.to_summarize_request(
2098 &model.model,
2099 CompletionIntent::ThreadSummarization,
2100 added_user_message.into(),
2101 cx,
2102 );
2103
2104 self.summary = ThreadSummary::Generating;
2105
2106 self.pending_summary = cx.spawn(async move |this, cx| {
2107 let result = async {
2108 let mut messages = model.model.stream_completion(request, &cx).await?;
2109
2110 let mut new_summary = String::new();
2111 while let Some(event) = messages.next().await {
2112 let Ok(event) = event else {
2113 continue;
2114 };
2115 let text = match event {
2116 LanguageModelCompletionEvent::Text(text) => text,
2117 LanguageModelCompletionEvent::StatusUpdate(
2118 CompletionRequestStatus::UsageUpdated { amount, limit },
2119 ) => {
2120 this.update(cx, |thread, cx| {
2121 thread.update_model_request_usage(amount as u32, limit, cx);
2122 })?;
2123 continue;
2124 }
2125 _ => continue,
2126 };
2127
2128 let mut lines = text.lines();
2129 new_summary.extend(lines.next());
2130
2131 // Stop if the LLM generated multiple lines.
2132 if lines.next().is_some() {
2133 break;
2134 }
2135 }
2136
2137 anyhow::Ok(new_summary)
2138 }
2139 .await;
2140
2141 this.update(cx, |this, cx| {
2142 match result {
2143 Ok(new_summary) => {
2144 if new_summary.is_empty() {
2145 this.summary = ThreadSummary::Error;
2146 } else {
2147 this.summary = ThreadSummary::Ready(new_summary.into());
2148 }
2149 }
2150 Err(err) => {
2151 this.summary = ThreadSummary::Error;
2152 log::error!("Failed to generate thread summary: {}", err);
2153 }
2154 }
2155 cx.emit(ThreadEvent::SummaryGenerated);
2156 })
2157 .log_err()?;
2158
2159 Some(())
2160 });
2161 }
2162
2163 fn handle_rate_limit_error(
2164 &mut self,
2165 error: &LanguageModelCompletionError,
2166 retry_after: Duration,
2167 model: Arc<dyn LanguageModel>,
2168 intent: CompletionIntent,
2169 window: Option<AnyWindowHandle>,
2170 cx: &mut Context<Self>,
2171 ) {
2172 // For rate limit errors, we only retry once with the specified duration
2173 let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
2174 log::warn!(
2175 "Retrying completion request in {} seconds: {error:?}",
2176 retry_after.as_secs(),
2177 );
2178
2179 // Add a UI-only message instead of a regular message
2180 let id = self.next_message_id.post_inc();
2181 self.messages.push(Message {
2182 id,
2183 role: Role::System,
2184 segments: vec![MessageSegment::Text(retry_message)],
2185 loaded_context: LoadedContext::default(),
2186 creases: Vec::new(),
2187 is_hidden: false,
2188 ui_only: true,
2189 });
2190 cx.emit(ThreadEvent::MessageAdded(id));
2191 // Schedule the retry
2192 let thread_handle = cx.entity().downgrade();
2193
2194 cx.spawn(async move |_thread, cx| {
2195 cx.background_executor().timer(retry_after).await;
2196
2197 thread_handle
2198 .update(cx, |thread, cx| {
2199 // Retry the completion
2200 thread.send_to_model(model, intent, window, cx);
2201 })
2202 .log_err();
2203 })
2204 .detach();
2205 }
2206
    /// Schedules a retry with the default exponential backoff for a transient
    /// completion error. Returns `true` when a retry was scheduled.
    fn handle_retryable_error(
        &mut self,
        error: &LanguageModelCompletionError,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
    }
2217
    /// Schedules a retry for a transient completion error, tracking attempts
    /// in `self.retry_state`.
    ///
    /// Uses `custom_delay` when given, otherwise exponential backoff based on
    /// the attempt number. Inserts a UI-only system message showing the
    /// countdown. Returns `true` when a retry was scheduled; once
    /// `max_attempts` is exceeded, clears retry state, stops pending
    /// completions, emits [`ThreadEvent::RetriesFailed`], and returns `false`.
    fn handle_retryable_error_with_delay(
        &mut self,
        error: &LanguageModelCompletionError,
        custom_delay: Option<Duration>,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Retry state survives across calls so consecutive failures keep
        // incrementing the same attempt counter.
        let retry_state = self.retry_state.get_or_insert(RetryState {
            attempt: 0,
            max_attempts: MAX_RETRY_ATTEMPTS,
            intent,
        });

        retry_state.attempt += 1;
        let attempt = retry_state.attempt;
        let max_attempts = retry_state.max_attempts;
        let intent = retry_state.intent;

        if attempt <= max_attempts {
            // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
            let delay = if let Some(custom_delay) = custom_delay {
                custom_delay
            } else {
                let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
                Duration::from_secs(delay_secs)
            };

            // Add a transient message to inform the user
            let delay_secs = delay.as_secs();
            let retry_message = format!(
                "{error}. Retrying (attempt {attempt} of {max_attempts}) \
                in {delay_secs} seconds..."
            );
            log::warn!(
                "Retrying completion request (attempt {attempt} of {max_attempts}) \
                in {delay_secs} seconds: {error:?}",
            );

            // Add a UI-only message instead of a regular message
            let id = self.next_message_id.post_inc();
            self.messages.push(Message {
                id,
                role: Role::System,
                segments: vec![MessageSegment::Text(retry_message)],
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
                is_hidden: false,
                ui_only: true,
            });
            cx.emit(ThreadEvent::MessageAdded(id));

            // Schedule the retry
            let thread_handle = cx.entity().downgrade();

            cx.spawn(async move |_thread, cx| {
                cx.background_executor().timer(delay).await;

                thread_handle
                    .update(cx, |thread, cx| {
                        // Retry the completion
                        thread.send_to_model(model, intent, window, cx);
                    })
                    .log_err();
            })
            .detach();

            true
        } else {
            // Max retries exceeded
            self.retry_state = None;

            let notification_text = if max_attempts == 1 {
                "Failed after retrying.".into()
            } else {
                format!("Failed after retrying {} times.", max_attempts).into()
            };

            // Stop generating since we're giving up on retrying.
            self.pending_completions.clear();

            cx.emit(ThreadEvent::RetriesFailed {
                message: notification_text,
            });

            false
        }
    }
2307
    /// Starts generating a detailed summary of the thread in the background,
    /// unless a summary is already generated (or generating) for the latest
    /// message. On success the thread is saved via `thread_store` so the
    /// summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        // Summaries use the dedicated thread-summary model, not the
        // conversation model.
        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The stream failed to start; fall back to the not-generated state.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            // Accumulate the streamed chunks into the final summary text.
            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2397
    /// Waits until detailed-summary generation settles, returning the
    /// generated summary, or the raw thread text when no summary exists.
    /// Returns `None` when the thread entity or the channel has been dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still in progress; keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
2415
2416 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2417 self.detailed_summary_rx
2418 .borrow()
2419 .text()
2420 .unwrap_or_else(|| self.text().into())
2421 }
2422
2423 pub fn is_generating_detailed_summary(&self) -> bool {
2424 matches!(
2425 &*self.detailed_summary_rx.borrow(),
2426 DetailedSummaryState::Generating { .. }
2427 )
2428 }
2429
2430 pub fn use_pending_tools(
2431 &mut self,
2432 window: Option<AnyWindowHandle>,
2433 model: Arc<dyn LanguageModel>,
2434 cx: &mut Context<Self>,
2435 ) -> Vec<PendingToolUse> {
2436 self.auto_capture_telemetry(cx);
2437 let request =
2438 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2439 let pending_tool_uses = self
2440 .tool_use
2441 .pending_tool_uses()
2442 .into_iter()
2443 .filter(|tool_use| tool_use.status.is_idle())
2444 .cloned()
2445 .collect::<Vec<_>>();
2446
2447 for tool_use in pending_tool_uses.iter() {
2448 self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2449 }
2450
2451 pending_tool_uses
2452 }
2453
2454 fn use_pending_tool(
2455 &mut self,
2456 tool_use: PendingToolUse,
2457 request: Arc<LanguageModelRequest>,
2458 model: Arc<dyn LanguageModel>,
2459 window: Option<AnyWindowHandle>,
2460 cx: &mut Context<Self>,
2461 ) {
2462 let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2463 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2464 };
2465
2466 if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2467 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2468 }
2469
2470 if tool.needs_confirmation(&tool_use.input, cx)
2471 && !AgentSettings::get_global(cx).always_allow_tool_actions
2472 {
2473 self.tool_use.confirm_tool_use(
2474 tool_use.id,
2475 tool_use.ui_text,
2476 tool_use.input,
2477 request,
2478 tool,
2479 );
2480 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2481 } else {
2482 self.run_tool(
2483 tool_use.id,
2484 tool_use.ui_text,
2485 tool_use.input,
2486 request,
2487 tool,
2488 model,
2489 window,
2490 cx,
2491 );
2492 }
2493 }
2494
2495 pub fn handle_hallucinated_tool_use(
2496 &mut self,
2497 tool_use_id: LanguageModelToolUseId,
2498 hallucinated_tool_name: Arc<str>,
2499 window: Option<AnyWindowHandle>,
2500 cx: &mut Context<Thread>,
2501 ) {
2502 let available_tools = self.profile.enabled_tools(cx);
2503
2504 let tool_list = available_tools
2505 .iter()
2506 .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2507 .collect::<Vec<_>>()
2508 .join("\n");
2509
2510 let error_message = format!(
2511 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2512 hallucinated_tool_name, tool_list
2513 );
2514
2515 let pending_tool_use = self.tool_use.insert_tool_output(
2516 tool_use_id.clone(),
2517 hallucinated_tool_name,
2518 Err(anyhow!("Missing tool call: {error_message}")),
2519 self.configured_model.as_ref(),
2520 self.completion_mode,
2521 );
2522
2523 cx.emit(ThreadEvent::MissingToolUse {
2524 tool_use_id: tool_use_id.clone(),
2525 ui_text: error_message.into(),
2526 });
2527
2528 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2529 }
2530
2531 pub fn receive_invalid_tool_json(
2532 &mut self,
2533 tool_use_id: LanguageModelToolUseId,
2534 tool_name: Arc<str>,
2535 invalid_json: Arc<str>,
2536 error: String,
2537 window: Option<AnyWindowHandle>,
2538 cx: &mut Context<Thread>,
2539 ) {
2540 log::error!("The model returned invalid input JSON: {invalid_json}");
2541
2542 let pending_tool_use = self.tool_use.insert_tool_output(
2543 tool_use_id.clone(),
2544 tool_name,
2545 Err(anyhow!("Error parsing input JSON: {error}")),
2546 self.configured_model.as_ref(),
2547 self.completion_mode,
2548 );
2549 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2550 pending_tool_use.ui_text.clone()
2551 } else {
2552 log::error!(
2553 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2554 );
2555 format!("Unknown tool {}", tool_use_id).into()
2556 };
2557
2558 cx.emit(ThreadEvent::InvalidToolInput {
2559 tool_use_id: tool_use_id.clone(),
2560 ui_text,
2561 invalid_input_json: invalid_json,
2562 });
2563
2564 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2565 }
2566
2567 pub fn run_tool(
2568 &mut self,
2569 tool_use_id: LanguageModelToolUseId,
2570 ui_text: impl Into<SharedString>,
2571 input: serde_json::Value,
2572 request: Arc<LanguageModelRequest>,
2573 tool: Arc<dyn Tool>,
2574 model: Arc<dyn LanguageModel>,
2575 window: Option<AnyWindowHandle>,
2576 cx: &mut Context<Thread>,
2577 ) {
2578 let task =
2579 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2580 self.tool_use
2581 .run_pending_tool(tool_use_id, ui_text.into(), task);
2582 }
2583
    /// Starts running a tool and returns a task that resolves when it
    /// finishes; the task records the tool's output on the thread and
    /// triggers `tool_finished`.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; in that
                // case the result is silently discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                            thread.completion_mode,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2631
2632 fn tool_finished(
2633 &mut self,
2634 tool_use_id: LanguageModelToolUseId,
2635 pending_tool_use: Option<PendingToolUse>,
2636 canceled: bool,
2637 window: Option<AnyWindowHandle>,
2638 cx: &mut Context<Self>,
2639 ) {
2640 if self.all_tools_finished() {
2641 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2642 if !canceled {
2643 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2644 }
2645 self.auto_capture_telemetry(cx);
2646 }
2647 }
2648
2649 cx.emit(ThreadEvent::ToolFinished {
2650 tool_use_id,
2651 pending_tool_use,
2652 });
2653 }
2654
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // A scheduled retry also counts as an in-flight completion to cancel.
        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();

        self.retry_state = None;

        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            // `canceled = true` prevents tool_finished from kicking off a
            // follow-up completion with partial tool results.
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }
2693
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. No thread state is
    /// modified here; only the event is emitted.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2701
    /// Returns the thread-level feedback, if any was reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2705
    /// Returns the feedback reported for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2709
    /// Records feedback for a single message and reports it, together with a
    /// serialized thread and project snapshot, to telemetry. No-op when the
    /// same feedback was already recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything we need on the foreground before moving the
        // reporting work to a background task.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|(name, _)| name.clone().into())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2766
    /// Reports feedback for the thread: attaches it to the last assistant
    /// message when one exists, otherwise records thread-level feedback and
    /// reports it to telemetry directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach the feedback to; store it on the
            // thread itself and report without per-message fields.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2812
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record the paths of dirty buffers so the snapshot reflects
            // unsaved work at capture time.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2850
2851 fn worktree_snapshot(
2852 worktree: Entity<project::Worktree>,
2853 git_store: Entity<GitStore>,
2854 cx: &App,
2855 ) -> Task<WorktreeSnapshot> {
2856 cx.spawn(async move |cx| {
2857 // Get worktree path and snapshot
2858 let worktree_info = cx.update(|app_cx| {
2859 let worktree = worktree.read(app_cx);
2860 let path = worktree.abs_path().to_string_lossy().to_string();
2861 let snapshot = worktree.snapshot();
2862 (path, snapshot)
2863 });
2864
2865 let Ok((worktree_path, _snapshot)) = worktree_info else {
2866 return WorktreeSnapshot {
2867 worktree_path: String::new(),
2868 git_state: None,
2869 };
2870 };
2871
2872 let git_state = git_store
2873 .update(cx, |git_store, cx| {
2874 git_store
2875 .repositories()
2876 .values()
2877 .find(|repo| {
2878 repo.read(cx)
2879 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2880 .is_some()
2881 })
2882 .cloned()
2883 })
2884 .ok()
2885 .flatten()
2886 .map(|repo| {
2887 repo.update(cx, |repo, _| {
2888 let current_branch =
2889 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2890 repo.send_job(None, |state, _| async move {
2891 let RepositoryState::Local { backend, .. } = state else {
2892 return GitState {
2893 remote_url: None,
2894 head_sha: None,
2895 current_branch,
2896 diff: None,
2897 };
2898 };
2899
2900 let remote_url = backend.remote_url("origin");
2901 let head_sha = backend.head_sha().await;
2902 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2903
2904 GitState {
2905 remote_url,
2906 head_sha,
2907 current_branch,
2908 diff,
2909 }
2910 })
2911 })
2912 });
2913
2914 let git_state = match git_state {
2915 Some(git_state) => match git_state.ok() {
2916 Some(git_state) => git_state.await.ok(),
2917 None => None,
2918 },
2919 None => None,
2920 };
2921
2922 WorktreeSnapshot {
2923 worktree_path,
2924 git_state,
2925 }
2926 })
2927 }
2928
2929 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2930 let mut markdown = Vec::new();
2931
2932 let summary = self.summary().or_default();
2933 writeln!(markdown, "# {summary}\n")?;
2934
2935 for message in self.messages() {
2936 writeln!(
2937 markdown,
2938 "## {role}\n",
2939 role = match message.role {
2940 Role::User => "User",
2941 Role::Assistant => "Agent",
2942 Role::System => "System",
2943 }
2944 )?;
2945
2946 if !message.loaded_context.text.is_empty() {
2947 writeln!(markdown, "{}", message.loaded_context.text)?;
2948 }
2949
2950 if !message.loaded_context.images.is_empty() {
2951 writeln!(
2952 markdown,
2953 "\n{} images attached as context.\n",
2954 message.loaded_context.images.len()
2955 )?;
2956 }
2957
2958 for segment in &message.segments {
2959 match segment {
2960 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2961 MessageSegment::Thinking { text, .. } => {
2962 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2963 }
2964 MessageSegment::RedactedThinking(_) => {}
2965 }
2966 }
2967
2968 for tool_use in self.tool_uses_for_message(message.id, cx) {
2969 writeln!(
2970 markdown,
2971 "**Use Tool: {} ({})**",
2972 tool_use.name, tool_use.id
2973 )?;
2974 writeln!(markdown, "```json")?;
2975 writeln!(
2976 markdown,
2977 "{}",
2978 serde_json::to_string_pretty(&tool_use.input)?
2979 )?;
2980 writeln!(markdown, "```")?;
2981 }
2982
2983 for tool_result in self.tool_results_for_message(message.id) {
2984 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2985 if tool_result.is_error {
2986 write!(markdown, " (Error)")?;
2987 }
2988
2989 writeln!(markdown, "**\n")?;
2990 match &tool_result.content {
2991 LanguageModelToolResultContent::Text(text) => {
2992 writeln!(markdown, "{text}")?;
2993 }
2994 LanguageModelToolResultContent::Image(image) => {
2995 writeln!(markdown, "", image.source)?;
2996 }
2997 }
2998
2999 if let Some(output) = tool_result.output.as_ref() {
3000 writeln!(
3001 markdown,
3002 "\n\nDebug Output:\n\n```json\n{}\n```\n",
3003 serde_json::to_string_pretty(output)?
3004 )?;
3005 }
3006 }
3007 }
3008
3009 Ok(String::from_utf8_lossy(&markdown).to_string())
3010 }
3011
    /// Marks agent edits within `buffer_range` as accepted by the user.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
3022
    /// Marks every outstanding agent edit as accepted by the user.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
3027
    /// Rejects agent edits within the given buffer ranges; the returned task
    /// resolves once the rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
3038
    /// The log of buffer edits made by the agent in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
3042
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
3046
    /// Serializes and uploads the thread for telemetry when the auto-capture
    /// feature flag is enabled. Throttled to at most once every 10 seconds.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
            return;
        }

        // Throttle captures so we don't re-serialize the thread on every event.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client();
        let serialize_task = self.serialize(cx);

        // Serialization and upload both happen off the main thread.
        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }
3090
    /// Total token usage accumulated across the requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3094
3095 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3096 let Some(model) = self.configured_model.as_ref() else {
3097 return TotalTokenUsage::default();
3098 };
3099
3100 let max = model
3101 .model
3102 .max_token_count_for_mode(self.completion_mode().into());
3103
3104 let index = self
3105 .messages
3106 .iter()
3107 .position(|msg| msg.id == message_id)
3108 .unwrap_or(0);
3109
3110 if index == 0 {
3111 return TotalTokenUsage { total: 0, max };
3112 }
3113
3114 let token_usage = &self
3115 .request_token_usage
3116 .get(index - 1)
3117 .cloned()
3118 .unwrap_or_default();
3119
3120 TotalTokenUsage {
3121 total: token_usage.total_tokens(),
3122 max,
3123 }
3124 }
3125
3126 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3127 let model = self.configured_model.as_ref()?;
3128
3129 let max = model
3130 .model
3131 .max_token_count_for_mode(self.completion_mode().into());
3132
3133 if let Some(exceeded_error) = &self.exceeded_window_error {
3134 if model.model.id() == exceeded_error.model_id {
3135 return Some(TotalTokenUsage {
3136 total: exceeded_error.token_count,
3137 max,
3138 });
3139 }
3140 }
3141
3142 let total = self
3143 .token_usage_at_last_message()
3144 .unwrap_or_default()
3145 .total_tokens();
3146
3147 Some(TotalTokenUsage { total, max })
3148 }
3149
    /// Token usage recorded at the last message, falling back to the most
    /// recent entry when the usage list is shorter than the message list.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
3156
    /// Records `token_usage` for the last message, first growing the usage
    /// list to match the message count (padding with the previous value).
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
3166
    /// Pushes the latest model-request usage numbers into the user store.
    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            project.user_store().update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            })
        });
    }
3180
3181 pub fn deny_tool_use(
3182 &mut self,
3183 tool_use_id: LanguageModelToolUseId,
3184 tool_name: Arc<str>,
3185 window: Option<AnyWindowHandle>,
3186 cx: &mut Context<Self>,
3187 ) {
3188 let err = Err(anyhow::anyhow!(
3189 "Permission to run tool action denied by user"
3190 ));
3191
3192 self.tool_use.insert_tool_output(
3193 tool_use_id.clone(),
3194 tool_name,
3195 err,
3196 self.configured_model.as_ref(),
3197 self.completion_mode,
3198 );
3199 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3200 }
3201}
3202
/// User-facing errors surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan hit its model-request limit.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, displayed with a header and message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3215
/// Events emitted by a thread for its observers (see the `EventEmitter` impl).
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model called a tool that doesn't exist or is disabled.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input JSON that failed to parse.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        /// The raw JSON that failed to parse.
        invalid_input_json: Arc<str>,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread (including UI-only messages).
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    /// Listeners should cancel any in-progress editing operations.
    CancelEditing,
    /// An in-flight completion (or scheduled retry) was canceled.
    CompletionCanceled,
    ProfileChanged,
    /// All retry attempts for a failed completion were exhausted.
    RetriesFailed {
        message: SharedString,
    },
}
3263
// Threads broadcast `ThreadEvent`s to any GPUI observers.
impl EventEmitter<ThreadEvent> for Thread {}
3265
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Monotonically increasing id used to tell completions apart.
    id: usize,
    // Where this completion sits in the send queue.
    queue_state: QueueState,
    // Holds the streaming task alive for the duration of the completion.
    _task: Task<()>,
}
3271
3272#[cfg(test)]
3273mod tests {
3274 use super::*;
3275 use crate::{
3276 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3277 };
3278
    // Test-specific constants
    // Simulated provider retry-after interval, in seconds.
    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3281 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3282 use assistant_tool::ToolRegistry;
3283 use assistant_tools;
3284 use futures::StreamExt;
3285 use futures::future::BoxFuture;
3286 use futures::stream::BoxStream;
3287 use gpui::TestAppContext;
3288 use http_client;
3289 use indoc::indoc;
3290 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3291 use language_model::{
3292 LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3293 LanguageModelProviderName, LanguageModelToolChoice,
3294 };
3295 use parking_lot::Mutex;
3296 use project::{FakeFs, Project};
3297 use prompt_store::PromptBuilder;
3298 use serde_json::json;
3299 use settings::{Settings, SettingsStore};
3300 use std::sync::Arc;
3301 use std::time::Duration;
3302 use theme::ThemeSettings;
3303 use util::path;
3304 use workspace::Workspace;
3305
    // End-to-end check that attached file context is embedded both in the
    // stored message and in the completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain the system message plus our user message,
        // with the context block prepended to the user text.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3382
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        // Verifies that `new_context_for_thread` only returns context that has
        // not already been attached to an earlier message of this thread, and
        // that removals via `remove_context` are reflected in later queries.
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        // With no prior messages, file1 counts as new.
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        // Only file2 is new — file1 was attached to message 1.
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages (index 0 is the system prompt)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // NOTE(review): passing `Some(message2_id)` appears to mean "treat only
        // contexts attached up to (but not including) message 2 as already seen":
        // file1 stays excluded while file2, file3 and the new file4 are returned.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs (contexts are ordered, so index 2 is file4)
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs (index 1 after the previous removal)
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3538
3539 #[gpui::test]
3540 async fn test_message_without_files(cx: &mut TestAppContext) {
3541 init_test_settings(cx);
3542
3543 let project = create_test_project(
3544 cx,
3545 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3546 )
3547 .await;
3548
3549 let (_, _thread_store, thread, _context_store, model) =
3550 setup_test_environment(cx, project.clone()).await;
3551
3552 // Insert user message without any context (empty context vector)
3553 let message_id = thread.update(cx, |thread, cx| {
3554 thread.insert_user_message(
3555 "What is the best way to learn Rust?",
3556 ContextLoadResult::default(),
3557 None,
3558 Vec::new(),
3559 cx,
3560 )
3561 });
3562
3563 // Check content and context in message object
3564 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3565
3566 // Context should be empty when no files are included
3567 assert_eq!(message.role, Role::User);
3568 assert_eq!(message.segments.len(), 1);
3569 assert_eq!(
3570 message.segments[0],
3571 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3572 );
3573 assert_eq!(message.loaded_context.text, "");
3574
3575 // Check message in request
3576 let request = thread.update(cx, |thread, cx| {
3577 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3578 });
3579
3580 assert_eq!(request.messages.len(), 2);
3581 assert_eq!(
3582 request.messages[1].string_contents(),
3583 "What is the best way to learn Rust?"
3584 );
3585
3586 // Add second message, also without context
3587 let message2_id = thread.update(cx, |thread, cx| {
3588 thread.insert_user_message(
3589 "Are there any good books?",
3590 ContextLoadResult::default(),
3591 None,
3592 Vec::new(),
3593 cx,
3594 )
3595 });
3596
3597 let message2 =
3598 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3599 assert_eq!(message2.loaded_context.text, "");
3600
3601 // Check that both messages appear in the request
3602 let request = thread.update(cx, |thread, cx| {
3603 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3604 });
3605
3606 assert_eq!(request.messages.len(), 3);
3607 assert_eq!(
3608 request.messages[1].string_contents(),
3609 "What is the best way to learn Rust?"
3610 );
3611 assert_eq!(
3612 request.messages[2].string_contents(),
3613 "Are there any good books?"
3614 );
3615 }
3616
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        // Verifies that editing a buffer after it was attached as context makes
        // `flush_notifications` emit exactly one `project_notifications` tool
        // result describing the stale file — and that a second flush does not
        // produce a duplicate.
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context. This will be a tracked buffer
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message and assistant response
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });

        // We shouldn't have a stale buffer notification yet
        let notifications = thread.read_with(cx, |thread, _| {
            find_tool_uses(thread, "project_notifications")
        });
        assert!(
            notifications.is_empty(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Modify the buffer, making the context captured above stale
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check for the stale buffer warning
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        // Exactly one notification must have been emitted.
        let [notification] = notifications.as_slice() else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]

            These files have changed since the last read:
            - code.rs
            "};
        assert_eq!(notification_content, expected_content);

        // Insert another user message and flush notifications again
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Can you tell me more?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // There should be no new notifications (we already flushed one)
        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        assert_eq!(
            notifications.len(),
            1,
            "Should still have only one notification after second flush - no duplicates"
        );
    }
3730
3731 fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3732 thread
3733 .messages()
3734 .flat_map(|message| {
3735 thread
3736 .tool_results_for_message(message.id)
3737 .into_iter()
3738 .filter(|result| result.tool_name == tool_name.into())
3739 .cloned()
3740 .collect::<Vec<_>>()
3741 })
3742 .collect()
3743 }
3744
3745 #[gpui::test]
3746 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3747 init_test_settings(cx);
3748
3749 let project = create_test_project(
3750 cx,
3751 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3752 )
3753 .await;
3754
3755 let (_workspace, thread_store, thread, _context_store, _model) =
3756 setup_test_environment(cx, project.clone()).await;
3757
3758 // Check that we are starting with the default profile
3759 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3760 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3761 assert_eq!(
3762 profile,
3763 AgentProfile::new(AgentProfileId::default(), tool_set)
3764 );
3765 }
3766
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        // Round-trips a thread through serialize/deserialize and checks that
        // the (default) profile survives.
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form, reusing the original
        // thread's project / tools / prompt-builder / project-context handles.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        // The deserialized thread reconstructs the same default profile.
        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3809
3810 #[gpui::test]
3811 async fn test_temperature_setting(cx: &mut TestAppContext) {
3812 init_test_settings(cx);
3813
3814 let project = create_test_project(
3815 cx,
3816 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3817 )
3818 .await;
3819
3820 let (_workspace, _thread_store, thread, _context_store, model) =
3821 setup_test_environment(cx, project.clone()).await;
3822
3823 // Both model and provider
3824 cx.update(|cx| {
3825 AgentSettings::override_global(
3826 AgentSettings {
3827 model_parameters: vec![LanguageModelParameters {
3828 provider: Some(model.provider_id().0.to_string().into()),
3829 model: Some(model.id().0.clone()),
3830 temperature: Some(0.66),
3831 }],
3832 ..AgentSettings::get_global(cx).clone()
3833 },
3834 cx,
3835 );
3836 });
3837
3838 let request = thread.update(cx, |thread, cx| {
3839 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3840 });
3841 assert_eq!(request.temperature, Some(0.66));
3842
3843 // Only model
3844 cx.update(|cx| {
3845 AgentSettings::override_global(
3846 AgentSettings {
3847 model_parameters: vec![LanguageModelParameters {
3848 provider: None,
3849 model: Some(model.id().0.clone()),
3850 temperature: Some(0.66),
3851 }],
3852 ..AgentSettings::get_global(cx).clone()
3853 },
3854 cx,
3855 );
3856 });
3857
3858 let request = thread.update(cx, |thread, cx| {
3859 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3860 });
3861 assert_eq!(request.temperature, Some(0.66));
3862
3863 // Only provider
3864 cx.update(|cx| {
3865 AgentSettings::override_global(
3866 AgentSettings {
3867 model_parameters: vec![LanguageModelParameters {
3868 provider: Some(model.provider_id().0.to_string().into()),
3869 model: None,
3870 temperature: Some(0.66),
3871 }],
3872 ..AgentSettings::get_global(cx).clone()
3873 },
3874 cx,
3875 );
3876 });
3877
3878 let request = thread.update(cx, |thread, cx| {
3879 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3880 });
3881 assert_eq!(request.temperature, Some(0.66));
3882
3883 // Same model name, different provider
3884 cx.update(|cx| {
3885 AgentSettings::override_global(
3886 AgentSettings {
3887 model_parameters: vec![LanguageModelParameters {
3888 provider: Some("anthropic".into()),
3889 model: Some(model.id().0.clone()),
3890 temperature: Some(0.66),
3891 }],
3892 ..AgentSettings::get_global(cx).clone()
3893 },
3894 cx,
3895 );
3896 });
3897
3898 let request = thread.update(cx, |thread, cx| {
3899 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3900 });
3901 assert_eq!(request.temperature, None);
3902 }
3903
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        // Walks the summary state machine: Pending -> Generating -> Ready,
        // checking that manual `set_summary` is ignored until Ready.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        // Complete the main (conversation) request first.
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        cx.run_until_parked();
        // Stream the summary text in two chunks, then close the stream.
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3988
3989 #[gpui::test]
3990 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3991 init_test_settings(cx);
3992
3993 let project = create_test_project(cx, json!({})).await;
3994
3995 let (_, _thread_store, thread, _context_store, model) =
3996 setup_test_environment(cx, project.clone()).await;
3997
3998 test_summarize_error(&model, &thread, cx);
3999
4000 // Now we should be able to set a summary
4001 thread.update(cx, |thread, cx| {
4002 thread.set_summary("Brief Intro", cx);
4003 });
4004
4005 thread.read_with(cx, |thread, _| {
4006 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4007 assert_eq!(thread.summary().or_default(), "Brief Intro");
4008 });
4009 }
4010
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        // After a summarization failure, further messages must not re-trigger
        // summarization automatically, but an explicit `summarize` call must.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the summarization-failed (Error) state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        cx.run_until_parked();
        // Stream the retried summary and close the stream.
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
4061
    // Helper to create a model that returns errors
    /// Which failure [`ErrorInjector`] should surface from `stream_completion`.
    enum TestError {
        /// Simulates a provider "server overloaded" response.
        Overloaded,
        /// Simulates a provider internal-server-error response.
        InternalServerError,
    }
4067
    /// A `LanguageModel` wrapper around a [`FakeLanguageModel`] whose
    /// `stream_completion` always fails with the configured [`TestError`].
    /// Used by the retry tests below to exercise the thread's retry logic.
    struct ErrorInjector {
        // Delegate for all metadata/capability queries (id, name, tokens, ...).
        inner: Arc<FakeLanguageModel>,
        // The error every completion attempt will produce.
        error_type: TestError,
    }
4072
4073 impl ErrorInjector {
4074 fn new(error_type: TestError) -> Self {
4075 Self {
4076 inner: Arc::new(FakeLanguageModel::default()),
4077 error_type,
4078 }
4079 }
4080 }
4081
    impl LanguageModel for ErrorInjector {
        // All metadata/capability queries delegate to the wrapped fake model;
        // only `stream_completion` is overridden to inject a failure.
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Build the configured error before the async block so the
            // returned 'static future does not borrow `self`.
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
                    provider: self.provider_name(),
                    retry_after: None,
                },
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError {
                        provider: self.provider_name(),
                        message: "I'm a teapot orbiting the sun".to_string(),
                    }
                }
            };
            async move {
                // The request itself "succeeds"; the stream then yields a
                // single Err event and terminates.
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
4164
    #[gpui::test]
    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
        // A ServerOverloaded completion error must schedule a retry, record
        // `retry_state`, and surface exactly one ui-only retry notice.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // The failed attempt should have populated the retry bookkeeping.
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have default max attempts"
            );
        });

        // Check that a retry message was added
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("overloaded")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message"
            );
        });

        // Exactly one "Retrying ... seconds" notice should exist so far.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4237
    #[gpui::test]
    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
        // An ApiInternalServerError must also schedule a retry, and the retry
        // notice must mention the provider name ("Fake").
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns internal server error
        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Check retry state on thread
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have correct max attempts"
            );
        });

        // Check that a retry message was added with provider name
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("internal")
                                    && text.contains("Fake")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message with provider name"
            );
        });

        // Count retry messages
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4313
4314 #[gpui::test]
4315 async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4316 init_test_settings(cx);
4317
4318 let project = create_test_project(cx, json!({})).await;
4319 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4320
4321 // Create model that returns overloaded error
4322 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4323
4324 // Insert a user message
4325 thread.update(cx, |thread, cx| {
4326 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4327 });
4328
4329 // Track retry events and completion count
4330 // Track completion events
4331 let completion_count = Arc::new(Mutex::new(0));
4332 let completion_count_clone = completion_count.clone();
4333
4334 let _subscription = thread.update(cx, |_, cx| {
4335 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4336 if let ThreadEvent::NewRequest = event {
4337 *completion_count_clone.lock() += 1;
4338 }
4339 })
4340 });
4341
4342 // First attempt
4343 thread.update(cx, |thread, cx| {
4344 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4345 });
4346 cx.run_until_parked();
4347
4348 // Should have scheduled first retry - count retry messages
4349 let retry_count = thread.update(cx, |thread, _| {
4350 thread
4351 .messages
4352 .iter()
4353 .filter(|m| {
4354 m.ui_only
4355 && m.segments.iter().any(|s| {
4356 if let MessageSegment::Text(text) = s {
4357 text.contains("Retrying") && text.contains("seconds")
4358 } else {
4359 false
4360 }
4361 })
4362 })
4363 .count()
4364 });
4365 assert_eq!(retry_count, 1, "Should have scheduled first retry");
4366
4367 // Check retry state
4368 thread.read_with(cx, |thread, _| {
4369 assert!(thread.retry_state.is_some(), "Should have retry state");
4370 let retry_state = thread.retry_state.as_ref().unwrap();
4371 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4372 });
4373
4374 // Advance clock for first retry
4375 cx.executor()
4376 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4377 cx.run_until_parked();
4378
4379 // Should have scheduled second retry - count retry messages
4380 let retry_count = thread.update(cx, |thread, _| {
4381 thread
4382 .messages
4383 .iter()
4384 .filter(|m| {
4385 m.ui_only
4386 && m.segments.iter().any(|s| {
4387 if let MessageSegment::Text(text) = s {
4388 text.contains("Retrying") && text.contains("seconds")
4389 } else {
4390 false
4391 }
4392 })
4393 })
4394 .count()
4395 });
4396 assert_eq!(retry_count, 2, "Should have scheduled second retry");
4397
4398 // Check retry state updated
4399 thread.read_with(cx, |thread, _| {
4400 assert!(thread.retry_state.is_some(), "Should have retry state");
4401 let retry_state = thread.retry_state.as_ref().unwrap();
4402 assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
4403 assert_eq!(
4404 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4405 "Should have correct max attempts"
4406 );
4407 });
4408
4409 // Advance clock for second retry (exponential backoff)
4410 cx.executor()
4411 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
4412 cx.run_until_parked();
4413
4414 // Should have scheduled third retry
4415 // Count all retry messages now
4416 let retry_count = thread.update(cx, |thread, _| {
4417 thread
4418 .messages
4419 .iter()
4420 .filter(|m| {
4421 m.ui_only
4422 && m.segments.iter().any(|s| {
4423 if let MessageSegment::Text(text) = s {
4424 text.contains("Retrying") && text.contains("seconds")
4425 } else {
4426 false
4427 }
4428 })
4429 })
4430 .count()
4431 });
4432 assert_eq!(
4433 retry_count, MAX_RETRY_ATTEMPTS as usize,
4434 "Should have scheduled third retry"
4435 );
4436
4437 // Check retry state updated
4438 thread.read_with(cx, |thread, _| {
4439 assert!(thread.retry_state.is_some(), "Should have retry state");
4440 let retry_state = thread.retry_state.as_ref().unwrap();
4441 assert_eq!(
4442 retry_state.attempt, MAX_RETRY_ATTEMPTS,
4443 "Should be at max retry attempt"
4444 );
4445 assert_eq!(
4446 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4447 "Should have correct max attempts"
4448 );
4449 });
4450
4451 // Advance clock for third retry (exponential backoff)
4452 cx.executor()
4453 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
4454 cx.run_until_parked();
4455
4456 // No more retries should be scheduled after clock was advanced.
4457 let retry_count = thread.update(cx, |thread, _| {
4458 thread
4459 .messages
4460 .iter()
4461 .filter(|m| {
4462 m.ui_only
4463 && m.segments.iter().any(|s| {
4464 if let MessageSegment::Text(text) = s {
4465 text.contains("Retrying") && text.contains("seconds")
4466 } else {
4467 false
4468 }
4469 })
4470 })
4471 .count()
4472 });
4473 assert_eq!(
4474 retry_count, MAX_RETRY_ATTEMPTS as usize,
4475 "Should not exceed max retries"
4476 );
4477
4478 // Final completion count should be initial + max retries
4479 assert_eq!(
4480 *completion_count.lock(),
4481 (MAX_RETRY_ATTEMPTS + 1) as usize,
4482 "Should have made initial + max retry attempts"
4483 );
4484 }
4485
4486 #[gpui::test]
4487 async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4488 init_test_settings(cx);
4489
4490 let project = create_test_project(cx, json!({})).await;
4491 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4492
4493 // Create model that returns overloaded error
4494 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4495
4496 // Insert a user message
4497 thread.update(cx, |thread, cx| {
4498 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4499 });
4500
4501 // Track events
4502 let retries_failed = Arc::new(Mutex::new(false));
4503 let retries_failed_clone = retries_failed.clone();
4504
4505 let _subscription = thread.update(cx, |_, cx| {
4506 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4507 if let ThreadEvent::RetriesFailed { .. } = event {
4508 *retries_failed_clone.lock() = true;
4509 }
4510 })
4511 });
4512
4513 // Start initial completion
4514 thread.update(cx, |thread, cx| {
4515 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4516 });
4517 cx.run_until_parked();
4518
4519 // Advance through all retries
4520 for i in 0..MAX_RETRY_ATTEMPTS {
4521 let delay = if i == 0 {
4522 BASE_RETRY_DELAY_SECS
4523 } else {
4524 BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
4525 };
4526 cx.executor().advance_clock(Duration::from_secs(delay));
4527 cx.run_until_parked();
4528 }
4529
4530 // After the 3rd retry is scheduled, we need to wait for it to execute and fail
4531 // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
4532 let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
4533 cx.executor()
4534 .advance_clock(Duration::from_secs(final_delay));
4535 cx.run_until_parked();
4536
4537 let retry_count = thread.update(cx, |thread, _| {
4538 thread
4539 .messages
4540 .iter()
4541 .filter(|m| {
4542 m.ui_only
4543 && m.segments.iter().any(|s| {
4544 if let MessageSegment::Text(text) = s {
4545 text.contains("Retrying") && text.contains("seconds")
4546 } else {
4547 false
4548 }
4549 })
4550 })
4551 .count()
4552 });
4553
4554 // After max retries, should emit RetriesFailed event
4555 assert_eq!(
4556 retry_count, MAX_RETRY_ATTEMPTS as usize,
4557 "Should have attempted max retries"
4558 );
4559 assert!(
4560 *retries_failed.lock(),
4561 "Should emit RetriesFailed event after max retries exceeded"
4562 );
4563
4564 // Retry state should be cleared
4565 thread.read_with(cx, |thread, _| {
4566 assert!(
4567 thread.retry_state.is_none(),
4568 "Retry state should be cleared after max retries"
4569 );
4570
4571 // Verify we have the expected number of retry messages
4572 let retry_messages = thread
4573 .messages
4574 .iter()
4575 .filter(|msg| msg.ui_only && msg.role == Role::System)
4576 .count();
4577 assert_eq!(
4578 retry_messages, MAX_RETRY_ATTEMPTS as usize,
4579 "Should have one retry message per attempt"
4580 );
4581 });
4582 }
4583
    #[gpui::test]
    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
        // NOTE(review): despite its name, this test asserts that the retry
        // notification *persists* as a ui_only System message after a
        // successful retry (see the final assertions) — consider renaming.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // We'll use a wrapper to switch behavior after first failure: the
        // first stream_completion call errors with ServerOverloaded, and
        // every later call delegates to the inner FakeLanguageModel so the
        // retried request can be driven manually from the test.
        struct RetryTestModel {
            inner: Arc<FakeLanguageModel>,
            // Flipped to true by the first stream_completion call.
            failed_once: Arc<Mutex<bool>>,
        }

        impl LanguageModel for RetryTestModel {
            // Metadata/capability queries all delegate to the fake model.
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                // Check-then-set uses two separate lock() calls; not atomic,
                // but gpui tests are single-threaded so it cannot race here.
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }

            // Exposes the inner fake so the test can drive the retried
            // request (stream chunks, end the stream).
            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track when retry completes successfully: a StreamedCompletion
        // event means the retried request produced output.
        let retry_completed = Arc::new(Mutex::new(false));
        let retry_completed_clone = retry_completed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::StreamedCompletion = event {
                    *retry_completed_clone.lock() = true;
                }
            })
        });

        // Start completion; the first attempt fails and schedules a retry.
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Get the retry message ID (the ui_only System notification posted
        // when the first attempt failed).
        let retry_message_id = thread.read_with(cx, |thread, _| {
            thread
                .messages()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .map(|msg| msg.id)
                .expect("Should have a retry message")
        });

        // Advance the fake clock so the scheduled retry fires.
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // Stream some successful content through the inner fake model.
        let fake_model = model.as_fake();
        // After the retry, there should be a new pending completion
        let pending = fake_model.pending_completions();
        assert!(
            !pending.is_empty(),
            "Should have a pending completion after retry"
        );
        fake_model.stream_completion_response(&pending[0], "Success!");
        fake_model.end_completion_stream(&pending[0]);
        cx.run_until_parked();

        // Check that the retry completed successfully
        assert!(
            *retry_completed.lock(),
            "Retry should have completed successfully"
        );

        // Retry message should still exist but be marked as ui_only
        thread.read_with(cx, |thread, _| {
            let retry_msg = thread
                .message(retry_message_id)
                .expect("Retry message should still exist");
            assert!(retry_msg.ui_only, "Retry message should be ui_only");
            assert_eq!(
                retry_msg.role,
                Role::System,
                "Retry message should have System role"
            );
        });
    }
4752
    #[gpui::test]
    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create a model that fails once then succeeds: the first
        // stream_completion call errors with ServerOverloaded, and every
        // later call delegates to the inner FakeLanguageModel.
        struct FailOnceModel {
            inner: Arc<FakeLanguageModel>,
            // Flipped to true by the first stream_completion call.
            failed_once: Arc<Mutex<bool>>,
        }

        impl LanguageModel for FailOnceModel {
            // Metadata/capability queries all delegate to the fake model.
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                // Check-then-set uses two separate lock() calls; not atomic,
                // but gpui tests are single-threaded so it cannot race here.
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }
        }

        let fail_once_model = Arc::new(FailOnceModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Test message",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
        });

        // Start completion with fail-once model
        thread.update(cx, |thread, cx| {
            thread.send_to_model(
                fail_once_model.clone(),
                CompletionIntent::UserPrompt,
                None,
                cx,
            );
        });

        cx.run_until_parked();

        // Verify retry state exists after first failure
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_some(),
                "Should have retry state after failure"
            );
        });

        // Wait for retry delay
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // The retry should now use our FailOnceModel which should succeed
        // We need to help the FakeLanguageModel complete the stream
        let inner_fake = fail_once_model.inner.clone();

        // Wait a bit for the retry to start
        cx.run_until_parked();

        // Check for pending completions and complete them.
        // NOTE(review): if the retry never started, this `if let` silently
        // does nothing; the assistant-message assertion below is what
        // actually catches that case.
        if let Some(pending) = inner_fake.pending_completions().first() {
            inner_fake.stream_completion_response(pending, "Success!");
            inner_fake.end_completion_stream(pending);
        }
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after successful completion"
            );

            // A real (non-ui_only) assistant message proves the retried
            // request streamed its content successfully.
            let has_assistant_message = thread
                .messages
                .iter()
                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
            assert!(
                has_assistant_message,
                "Should have an assistant message after successful retry"
            );
        });
    }
4914
    #[gpui::test]
    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create a model that returns rate limit error with retry_after on
        // every attempt. Unlike overload errors, a rate-limit error should
        // schedule a single provider-timed retry without setting retry_state.
        struct RateLimitModel {
            inner: Arc<FakeLanguageModel>,
        }

        impl LanguageModel for RateLimitModel {
            // Metadata/capability queries all delegate to the fake model.
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            // Always fails with RateLimitExceeded carrying an explicit
            // retry_after hint of TEST_RATE_LIMIT_RETRY_SECS.
            fn stream_completion(
                &self,
                _request: LanguageModelRequest,
                _cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                let provider = self.provider_name();
                async move {
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::RateLimitExceeded {
                            provider,
                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
                        })
                    });
                    Ok(stream.boxed())
                }
                .boxed()
            }

            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RateLimitModel {
            inner: Arc::new(FakeLanguageModel::default()),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Exactly one ui-only "rate limit exceeded" notification expected.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("rate limit exceeded")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(retry_count, 1, "Should have scheduled one retry");

        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Rate limit errors should not set retry_state"
            );
        });

        // Verify we have one retry message
        // NOTE(review): this repeats the exact count performed above and
        // could be removed without losing coverage.
        thread.read_with(cx, |thread, _| {
            let retry_messages = thread
                .messages
                .iter()
                .filter(|msg| {
                    msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("rate limit exceeded")
                            } else {
                                false
                            }
                        })
                })
                .count();
            assert_eq!(
                retry_messages, 1,
                "Should have one rate limit retry message"
            );
        });

        // Check that retry message doesn't include attempt count (rate-limit
        // retries are single-shot, so "attempt X of Y" would be misleading).
        thread.read_with(cx, |thread, _| {
            let retry_message = thread
                .messages
                .iter()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .expect("Should have a retry message");

            // Check that the message doesn't contain attempt count
            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
                assert!(
                    !text.contains("attempt"),
                    "Rate limit retry message should not contain attempt count"
                );
                assert!(
                    text.contains(&format!(
                        "Retrying in {} seconds",
                        TEST_RATE_LIMIT_RETRY_SECS
                    )),
                    "Rate limit retry message should contain retry delay"
                );
            }
        });
    }
5091
5092 #[gpui::test]
5093 async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5094 init_test_settings(cx);
5095
5096 let project = create_test_project(cx, json!({})).await;
5097 let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5098
5099 // Insert a regular user message
5100 thread.update(cx, |thread, cx| {
5101 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5102 });
5103
5104 // Insert a UI-only message (like our retry notifications)
5105 thread.update(cx, |thread, cx| {
5106 let id = thread.next_message_id.post_inc();
5107 thread.messages.push(Message {
5108 id,
5109 role: Role::System,
5110 segments: vec![MessageSegment::Text(
5111 "This is a UI-only message that should not be sent to the model".to_string(),
5112 )],
5113 loaded_context: LoadedContext::default(),
5114 creases: Vec::new(),
5115 is_hidden: true,
5116 ui_only: true,
5117 });
5118 cx.emit(ThreadEvent::MessageAdded(id));
5119 });
5120
5121 // Insert another regular message
5122 thread.update(cx, |thread, cx| {
5123 thread.insert_user_message(
5124 "How are you?",
5125 ContextLoadResult::default(),
5126 None,
5127 vec![],
5128 cx,
5129 );
5130 });
5131
5132 // Generate the completion request
5133 let request = thread.update(cx, |thread, cx| {
5134 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5135 });
5136
5137 // Verify that the request only contains non-UI-only messages
5138 // Should have system prompt + 2 user messages, but not the UI-only message
5139 let user_messages: Vec<_> = request
5140 .messages
5141 .iter()
5142 .filter(|msg| msg.role == Role::User)
5143 .collect();
5144 assert_eq!(
5145 user_messages.len(),
5146 2,
5147 "Should have exactly 2 user messages"
5148 );
5149
5150 // Verify the UI-only content is not present anywhere in the request
5151 let request_text = request
5152 .messages
5153 .iter()
5154 .flat_map(|msg| &msg.content)
5155 .filter_map(|content| match content {
5156 MessageContent::Text(text) => Some(text.as_str()),
5157 _ => None,
5158 })
5159 .collect::<String>();
5160
5161 assert!(
5162 !request_text.contains("UI-only message"),
5163 "UI-only message content should not be in the request"
5164 );
5165
5166 // Verify the thread still has all 3 messages (including UI-only)
5167 thread.read_with(cx, |thread, _| {
5168 assert_eq!(
5169 thread.messages().count(),
5170 3,
5171 "Thread should have 3 messages"
5172 );
5173 assert_eq!(
5174 thread.messages().filter(|m| m.ui_only).count(),
5175 1,
5176 "Thread should have 1 UI-only message"
5177 );
5178 });
5179
5180 // Verify that UI-only messages are not serialized
5181 let serialized = thread
5182 .update(cx, |thread, cx| thread.serialize(cx))
5183 .await
5184 .unwrap();
5185 assert_eq!(
5186 serialized.messages.len(),
5187 2,
5188 "Serialized thread should only have 2 messages (no UI-only)"
5189 );
5190 }
5191
5192 #[gpui::test]
5193 async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5194 init_test_settings(cx);
5195
5196 let project = create_test_project(cx, json!({})).await;
5197 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5198
5199 // Create model that returns overloaded error
5200 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5201
5202 // Insert a user message
5203 thread.update(cx, |thread, cx| {
5204 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5205 });
5206
5207 // Start completion
5208 thread.update(cx, |thread, cx| {
5209 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5210 });
5211
5212 cx.run_until_parked();
5213
5214 // Verify retry was scheduled by checking for retry message
5215 let has_retry_message = thread.read_with(cx, |thread, _| {
5216 thread.messages.iter().any(|m| {
5217 m.ui_only
5218 && m.segments.iter().any(|s| {
5219 if let MessageSegment::Text(text) = s {
5220 text.contains("Retrying") && text.contains("seconds")
5221 } else {
5222 false
5223 }
5224 })
5225 })
5226 });
5227 assert!(has_retry_message, "Should have scheduled a retry");
5228
5229 // Cancel the completion before the retry happens
5230 thread.update(cx, |thread, cx| {
5231 thread.cancel_last_completion(None, cx);
5232 });
5233
5234 cx.run_until_parked();
5235
5236 // The retry should not have happened - no pending completions
5237 let fake_model = model.as_fake();
5238 assert_eq!(
5239 fake_model.pending_completions().len(),
5240 0,
5241 "Should have no pending completions after cancellation"
5242 );
5243
5244 // Verify the retry was cancelled by checking retry state
5245 thread.read_with(cx, |thread, _| {
5246 if let Some(retry_state) = &thread.retry_state {
5247 panic!(
5248 "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5249 retry_state.attempt, retry_state.max_attempts, retry_state.intent
5250 );
5251 }
5252 });
5253 }
5254
    /// Shared assertion helper: drives a summarization request whose stream
    /// ends without emitting any summary text, then checks the thread lands
    /// in `ThreadSummary::Error` while still reporting the default title.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        // Send a user message with the summarization intent.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // While the summary request is in flight the summary is Generating,
        // and consumers fall back to the default placeholder title.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending (the stream closes with no text).
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
5289
    /// Drives the fake model through a minimal successful completion:
    /// streams one response chunk and closes the stream, parking the
    /// executor before and after so all spawned tasks settle.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
5296
    /// Registers every global/settings subsystem the thread tests touch.
    ///
    /// Initialization order matters: the `SettingsStore` global must be set
    /// before any of the `init`/`register` calls that read settings.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
            assistant_tool::init(cx);

            // Tool initialization needs an HTTP client; a fake client that
            // answers every request with a 200 keeps the tests offline.
            let http_client = Arc::new(http_client::HttpClientWithUrl::new(
                http_client::FakeHttpClient::with_200_response(),
                "http://localhost".to_string(),
                None,
            ));
            assistant_tools::init(http_client, cx);
        });
    }
5320
5321 // Helper to create a test project with test files
5322 async fn create_test_project(
5323 cx: &mut TestAppContext,
5324 files: serde_json::Value,
5325 ) -> Entity<Project> {
5326 let fs = FakeFs::new(cx.executor());
5327 fs.insert_tree(path!("/test"), files).await;
5328 Project::test(fs, [path!("/test").as_ref()], cx).await
5329 }
5330
5331 async fn setup_test_environment(
5332 cx: &mut TestAppContext,
5333 project: Entity<Project>,
5334 ) -> (
5335 Entity<Workspace>,
5336 Entity<ThreadStore>,
5337 Entity<Thread>,
5338 Entity<ContextStore>,
5339 Arc<dyn LanguageModel>,
5340 ) {
5341 let (workspace, cx) =
5342 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5343
5344 let thread_store = cx
5345 .update(|_, cx| {
5346 ThreadStore::load(
5347 project.clone(),
5348 cx.new(|_| ToolWorkingSet::default()),
5349 None,
5350 Arc::new(PromptBuilder::new(None).unwrap()),
5351 cx,
5352 )
5353 })
5354 .await
5355 .unwrap();
5356
5357 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5358 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5359
5360 let provider = Arc::new(FakeLanguageModelProvider);
5361 let model = provider.test_model();
5362 let model: Arc<dyn LanguageModel> = Arc::new(model);
5363
5364 cx.update(|_, cx| {
5365 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5366 registry.set_default_model(
5367 Some(ConfiguredModel {
5368 provider: provider.clone(),
5369 model: model.clone(),
5370 }),
5371 cx,
5372 );
5373 registry.set_thread_summary_model(
5374 Some(ConfiguredModel {
5375 provider,
5376 model: model.clone(),
5377 }),
5378 cx,
5379 );
5380 })
5381 });
5382
5383 (workspace, thread_store, thread, context_store, model)
5384 }
5385
5386 async fn add_file_to_context(
5387 project: &Entity<Project>,
5388 context_store: &Entity<ContextStore>,
5389 path: &str,
5390 cx: &mut TestAppContext,
5391 ) -> Result<Entity<language::Buffer>> {
5392 let buffer_path = project
5393 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5394 .unwrap();
5395
5396 let buffer = project
5397 .update(cx, |project, cx| {
5398 project.open_buffer(buffer_path.clone(), cx)
5399 })
5400 .await
5401 .unwrap();
5402
5403 context_store.update(cx, |context_store, cx| {
5404 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5405 });
5406
5407 Ok(buffer)
5408 }
5409}