1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9 tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
10};
11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
12use anyhow::{Result, anyhow};
13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
14use chrono::{DateTime, Utc};
15use client::{ModelRequestUsage, RequestUsage};
16use collections::HashMap;
17use feature_flags::{self, FeatureFlagAppExt};
18use futures::{FutureExt, StreamExt as _, future::Shared};
19use git::repository::DiffType;
20use gpui::{
21 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
22 WeakEntity, Window,
23};
24use language_model::{
25 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
27 LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
28 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
29 PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
30};
31use postage::stream::Stream as _;
32use project::{
33 Project,
34 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
35};
36use prompt_store::{ModelContext, PromptBuilder};
37use proto::Plan;
38use schemars::JsonSchema;
39use serde::{Deserialize, Serialize};
40use settings::Settings;
41use std::{
42 io::Write,
43 ops::Range,
44 sync::Arc,
45 time::{Duration, Instant},
46};
47use thiserror::Error;
48use util::{ResultExt as _, debug_panic, post_inc};
49use uuid::Uuid;
50use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
51
52const MAX_RETRY_ATTEMPTS: u8 = 3;
53const BASE_RETRY_DELAY_SECS: u64 = 5;
54
55#[derive(
56 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
57)]
58pub struct ThreadId(Arc<str>);
59
60impl ThreadId {
61 pub fn new() -> Self {
62 Self(Uuid::new_v4().to_string().into())
63 }
64}
65
66impl std::fmt::Display for ThreadId {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 write!(f, "{}", self.0)
69 }
70}
71
72impl From<&str> for ThreadId {
73 fn from(value: &str) -> Self {
74 Self(value.into())
75 }
76}
77
78/// The ID of the user prompt that initiated a request.
79///
80/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
81#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
82pub struct PromptId(Arc<str>);
83
84impl PromptId {
85 pub fn new() -> Self {
86 Self(Uuid::new_v4().to_string().into())
87 }
88}
89
90impl std::fmt::Display for PromptId {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 write!(f, "{}", self.0)
93 }
94}
95
96#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
97pub struct MessageId(pub(crate) usize);
98
99impl MessageId {
100 fn post_inc(&mut self) -> Self {
101 Self(post_inc(&mut self.0))
102 }
103
104 pub fn as_usize(&self) -> usize {
105 self.0
106 }
107}
108
109/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
110#[derive(Clone, Debug)]
111pub struct MessageCrease {
112 pub range: Range<usize>,
113 pub icon_path: SharedString,
114 pub label: SharedString,
115 /// None for a deserialized message, Some otherwise.
116 pub context: Option<AgentContextHandle>,
117}
118
119/// A message in a [`Thread`].
120#[derive(Debug, Clone)]
121pub struct Message {
122 pub id: MessageId,
123 pub role: Role,
124 pub segments: Vec<MessageSegment>,
125 pub loaded_context: LoadedContext,
126 pub creases: Vec<MessageCrease>,
127 pub is_hidden: bool,
128 pub ui_only: bool,
129}
130
131impl Message {
132 /// Returns whether the message contains any meaningful text that should be displayed
133 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
134 pub fn should_display_content(&self) -> bool {
135 self.segments.iter().all(|segment| segment.should_display())
136 }
137
138 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
139 if let Some(MessageSegment::Thinking {
140 text: segment,
141 signature: current_signature,
142 }) = self.segments.last_mut()
143 {
144 if let Some(signature) = signature {
145 *current_signature = Some(signature);
146 }
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Thinking {
150 text: text.to_string(),
151 signature,
152 });
153 }
154 }
155
156 pub fn push_redacted_thinking(&mut self, data: String) {
157 self.segments.push(MessageSegment::RedactedThinking(data));
158 }
159
160 pub fn push_text(&mut self, text: &str) {
161 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
162 segment.push_str(text);
163 } else {
164 self.segments.push(MessageSegment::Text(text.to_string()));
165 }
166 }
167
168 pub fn to_string(&self) -> String {
169 let mut result = String::new();
170
171 if !self.loaded_context.text.is_empty() {
172 result.push_str(&self.loaded_context.text);
173 }
174
175 for segment in &self.segments {
176 match segment {
177 MessageSegment::Text(text) => result.push_str(text),
178 MessageSegment::Thinking { text, .. } => {
179 result.push_str("<think>\n");
180 result.push_str(text);
181 result.push_str("\n</think>");
182 }
183 MessageSegment::RedactedThinking(_) => {}
184 }
185 }
186
187 result
188 }
189}
190
191#[derive(Debug, Clone, PartialEq, Eq)]
192pub enum MessageSegment {
193 Text(String),
194 Thinking {
195 text: String,
196 signature: Option<String>,
197 },
198 RedactedThinking(String),
199}
200
201impl MessageSegment {
202 pub fn should_display(&self) -> bool {
203 match self {
204 Self::Text(text) => text.is_empty(),
205 Self::Thinking { text, .. } => text.is_empty(),
206 Self::RedactedThinking(_) => false,
207 }
208 }
209
210 pub fn text(&self) -> Option<&str> {
211 match self {
212 MessageSegment::Text(text) => Some(text),
213 _ => None,
214 }
215 }
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
219pub struct ProjectSnapshot {
220 pub worktree_snapshots: Vec<WorktreeSnapshot>,
221 pub unsaved_buffer_paths: Vec<String>,
222 pub timestamp: DateTime<Utc>,
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
226pub struct WorktreeSnapshot {
227 pub worktree_path: String,
228 pub git_state: Option<GitState>,
229}
230
231#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
232pub struct GitState {
233 pub remote_url: Option<String>,
234 pub head_sha: Option<String>,
235 pub current_branch: Option<String>,
236 pub diff: Option<String>,
237}
238
239#[derive(Clone, Debug)]
240pub struct ThreadCheckpoint {
241 message_id: MessageId,
242 git_checkpoint: GitStoreCheckpoint,
243}
244
245#[derive(Copy, Clone, Debug, PartialEq, Eq)]
246pub enum ThreadFeedback {
247 Positive,
248 Negative,
249}
250
251pub enum LastRestoreCheckpoint {
252 Pending {
253 message_id: MessageId,
254 },
255 Error {
256 message_id: MessageId,
257 error: String,
258 },
259}
260
261impl LastRestoreCheckpoint {
262 pub fn message_id(&self) -> MessageId {
263 match self {
264 LastRestoreCheckpoint::Pending { message_id } => *message_id,
265 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
266 }
267 }
268}
269
270#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
271pub enum DetailedSummaryState {
272 #[default]
273 NotGenerated,
274 Generating {
275 message_id: MessageId,
276 },
277 Generated {
278 text: SharedString,
279 message_id: MessageId,
280 },
281}
282
283impl DetailedSummaryState {
284 fn text(&self) -> Option<SharedString> {
285 if let Self::Generated { text, .. } = self {
286 Some(text.clone())
287 } else {
288 None
289 }
290 }
291}
292
293#[derive(Default, Debug)]
294pub struct TotalTokenUsage {
295 pub total: u64,
296 pub max: u64,
297}
298
299impl TotalTokenUsage {
300 pub fn ratio(&self) -> TokenUsageRatio {
301 #[cfg(debug_assertions)]
302 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
303 .unwrap_or("0.8".to_string())
304 .parse()
305 .unwrap();
306 #[cfg(not(debug_assertions))]
307 let warning_threshold: f32 = 0.8;
308
309 // When the maximum is unknown because there is no selected model,
310 // avoid showing the token limit warning.
311 if self.max == 0 {
312 TokenUsageRatio::Normal
313 } else if self.total >= self.max {
314 TokenUsageRatio::Exceeded
315 } else if self.total as f32 / self.max as f32 >= warning_threshold {
316 TokenUsageRatio::Warning
317 } else {
318 TokenUsageRatio::Normal
319 }
320 }
321
322 pub fn add(&self, tokens: u64) -> TotalTokenUsage {
323 TotalTokenUsage {
324 total: self.total + tokens,
325 max: self.max,
326 }
327 }
328}
329
330#[derive(Debug, Default, PartialEq, Eq)]
331pub enum TokenUsageRatio {
332 #[default]
333 Normal,
334 Warning,
335 Exceeded,
336}
337
338#[derive(Debug, Clone, Copy)]
339pub enum QueueState {
340 Sending,
341 Queued { position: usize },
342 Started,
343}
344
345/// A thread of conversation with the LLM.
346pub struct Thread {
347 id: ThreadId,
348 updated_at: DateTime<Utc>,
349 summary: ThreadSummary,
350 pending_summary: Task<Option<()>>,
351 detailed_summary_task: Task<Option<()>>,
352 detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
353 detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
354 completion_mode: agent_settings::CompletionMode,
355 messages: Vec<Message>,
356 next_message_id: MessageId,
357 last_prompt_id: PromptId,
358 project_context: SharedProjectContext,
359 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
360 completion_count: usize,
361 pending_completions: Vec<PendingCompletion>,
362 project: Entity<Project>,
363 prompt_builder: Arc<PromptBuilder>,
364 tools: Entity<ToolWorkingSet>,
365 tool_use: ToolUseState,
366 action_log: Entity<ActionLog>,
367 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
368 pending_checkpoint: Option<ThreadCheckpoint>,
369 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
370 request_token_usage: Vec<TokenUsage>,
371 cumulative_token_usage: TokenUsage,
372 exceeded_window_error: Option<ExceededWindowError>,
373 tool_use_limit_reached: bool,
374 feedback: Option<ThreadFeedback>,
375 retry_state: Option<RetryState>,
376 message_feedback: HashMap<MessageId, ThreadFeedback>,
377 last_auto_capture_at: Option<Instant>,
378 last_received_chunk_at: Option<Instant>,
379 request_callback: Option<
380 Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
381 >,
382 remaining_turns: u32,
383 configured_model: Option<ConfiguredModel>,
384 profile: AgentProfile,
385}
386
387#[derive(Clone, Debug)]
388struct RetryState {
389 attempt: u8,
390 max_attempts: u8,
391 intent: CompletionIntent,
392}
393
394#[derive(Clone, Debug, PartialEq, Eq)]
395pub enum ThreadSummary {
396 Pending,
397 Generating,
398 Ready(SharedString),
399 Error,
400}
401
402impl ThreadSummary {
403 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
404
405 pub fn or_default(&self) -> SharedString {
406 self.unwrap_or(Self::DEFAULT)
407 }
408
409 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
410 self.ready().unwrap_or_else(|| message.into())
411 }
412
413 pub fn ready(&self) -> Option<SharedString> {
414 match self {
415 ThreadSummary::Ready(summary) => Some(summary.clone()),
416 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
417 }
418 }
419}
420
421#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
422pub struct ExceededWindowError {
423 /// Model used when last message exceeded context window
424 model_id: LanguageModelId,
425 /// Token count including last message
426 token_count: u64,
427}
428
429impl Thread {
430 pub fn new(
431 project: Entity<Project>,
432 tools: Entity<ToolWorkingSet>,
433 prompt_builder: Arc<PromptBuilder>,
434 system_prompt: SharedProjectContext,
435 cx: &mut Context<Self>,
436 ) -> Self {
437 let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
438 let configured_model = LanguageModelRegistry::read_global(cx).default_model();
439 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
440
441 Self {
442 id: ThreadId::new(),
443 updated_at: Utc::now(),
444 summary: ThreadSummary::Pending,
445 pending_summary: Task::ready(None),
446 detailed_summary_task: Task::ready(None),
447 detailed_summary_tx,
448 detailed_summary_rx,
449 completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
450 messages: Vec::new(),
451 next_message_id: MessageId(0),
452 last_prompt_id: PromptId::new(),
453 project_context: system_prompt,
454 checkpoints_by_message: HashMap::default(),
455 completion_count: 0,
456 pending_completions: Vec::new(),
457 project: project.clone(),
458 prompt_builder,
459 tools: tools.clone(),
460 last_restore_checkpoint: None,
461 pending_checkpoint: None,
462 tool_use: ToolUseState::new(tools.clone()),
463 action_log: cx.new(|_| ActionLog::new(project.clone())),
464 initial_project_snapshot: {
465 let project_snapshot = Self::project_snapshot(project, cx);
466 cx.foreground_executor()
467 .spawn(async move { Some(project_snapshot.await) })
468 .shared()
469 },
470 request_token_usage: Vec::new(),
471 cumulative_token_usage: TokenUsage::default(),
472 exceeded_window_error: None,
473 tool_use_limit_reached: false,
474 feedback: None,
475 retry_state: None,
476 message_feedback: HashMap::default(),
477 last_auto_capture_at: None,
478 last_received_chunk_at: None,
479 request_callback: None,
480 remaining_turns: u32::MAX,
481 configured_model,
482 profile: AgentProfile::new(profile_id, tools),
483 }
484 }
485
486 pub fn deserialize(
487 id: ThreadId,
488 serialized: SerializedThread,
489 project: Entity<Project>,
490 tools: Entity<ToolWorkingSet>,
491 prompt_builder: Arc<PromptBuilder>,
492 project_context: SharedProjectContext,
493 window: Option<&mut Window>, // None in headless mode
494 cx: &mut Context<Self>,
495 ) -> Self {
496 let next_message_id = MessageId(
497 serialized
498 .messages
499 .last()
500 .map(|message| message.id.0 + 1)
501 .unwrap_or(0),
502 );
503 let tool_use = ToolUseState::from_serialized_messages(
504 tools.clone(),
505 &serialized.messages,
506 project.clone(),
507 window,
508 cx,
509 );
510 let (detailed_summary_tx, detailed_summary_rx) =
511 postage::watch::channel_with(serialized.detailed_summary_state);
512
513 let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
514 serialized
515 .model
516 .and_then(|model| {
517 let model = SelectedModel {
518 provider: model.provider.clone().into(),
519 model: model.model.clone().into(),
520 };
521 registry.select_model(&model, cx)
522 })
523 .or_else(|| registry.default_model())
524 });
525
526 let completion_mode = serialized
527 .completion_mode
528 .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
529 let profile_id = serialized
530 .profile
531 .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
532
533 Self {
534 id,
535 updated_at: serialized.updated_at,
536 summary: ThreadSummary::Ready(serialized.summary),
537 pending_summary: Task::ready(None),
538 detailed_summary_task: Task::ready(None),
539 detailed_summary_tx,
540 detailed_summary_rx,
541 completion_mode,
542 retry_state: None,
543 messages: serialized
544 .messages
545 .into_iter()
546 .map(|message| Message {
547 id: message.id,
548 role: message.role,
549 segments: message
550 .segments
551 .into_iter()
552 .map(|segment| match segment {
553 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
554 SerializedMessageSegment::Thinking { text, signature } => {
555 MessageSegment::Thinking { text, signature }
556 }
557 SerializedMessageSegment::RedactedThinking { data } => {
558 MessageSegment::RedactedThinking(data)
559 }
560 })
561 .collect(),
562 loaded_context: LoadedContext {
563 contexts: Vec::new(),
564 text: message.context,
565 images: Vec::new(),
566 },
567 creases: message
568 .creases
569 .into_iter()
570 .map(|crease| MessageCrease {
571 range: crease.start..crease.end,
572 icon_path: crease.icon_path,
573 label: crease.label,
574 context: None,
575 })
576 .collect(),
577 is_hidden: message.is_hidden,
578 ui_only: false, // UI-only messages are not persisted
579 })
580 .collect(),
581 next_message_id,
582 last_prompt_id: PromptId::new(),
583 project_context,
584 checkpoints_by_message: HashMap::default(),
585 completion_count: 0,
586 pending_completions: Vec::new(),
587 last_restore_checkpoint: None,
588 pending_checkpoint: None,
589 project: project.clone(),
590 prompt_builder,
591 tools: tools.clone(),
592 tool_use,
593 action_log: cx.new(|_| ActionLog::new(project)),
594 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
595 request_token_usage: serialized.request_token_usage,
596 cumulative_token_usage: serialized.cumulative_token_usage,
597 exceeded_window_error: None,
598 tool_use_limit_reached: serialized.tool_use_limit_reached,
599 feedback: None,
600 message_feedback: HashMap::default(),
601 last_auto_capture_at: None,
602 last_received_chunk_at: None,
603 request_callback: None,
604 remaining_turns: u32::MAX,
605 configured_model,
606 profile: AgentProfile::new(profile_id, tools),
607 }
608 }
609
610 pub fn set_request_callback(
611 &mut self,
612 callback: impl 'static
613 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
614 ) {
615 self.request_callback = Some(Box::new(callback));
616 }
617
618 pub fn id(&self) -> &ThreadId {
619 &self.id
620 }
621
622 pub fn profile(&self) -> &AgentProfile {
623 &self.profile
624 }
625
626 pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
627 if &id != self.profile.id() {
628 self.profile = AgentProfile::new(id, self.tools.clone());
629 cx.emit(ThreadEvent::ProfileChanged);
630 }
631 }
632
633 pub fn is_empty(&self) -> bool {
634 self.messages.is_empty()
635 }
636
637 pub fn updated_at(&self) -> DateTime<Utc> {
638 self.updated_at
639 }
640
641 pub fn touch_updated_at(&mut self) {
642 self.updated_at = Utc::now();
643 }
644
645 pub fn advance_prompt_id(&mut self) {
646 self.last_prompt_id = PromptId::new();
647 }
648
649 pub fn project_context(&self) -> SharedProjectContext {
650 self.project_context.clone()
651 }
652
653 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
654 if self.configured_model.is_none() {
655 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
656 }
657 self.configured_model.clone()
658 }
659
660 pub fn configured_model(&self) -> Option<ConfiguredModel> {
661 self.configured_model.clone()
662 }
663
664 pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
665 self.configured_model = model;
666 cx.notify();
667 }
668
669 pub fn summary(&self) -> &ThreadSummary {
670 &self.summary
671 }
672
673 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
674 let current_summary = match &self.summary {
675 ThreadSummary::Pending | ThreadSummary::Generating => return,
676 ThreadSummary::Ready(summary) => summary,
677 ThreadSummary::Error => &ThreadSummary::DEFAULT,
678 };
679
680 let mut new_summary = new_summary.into();
681
682 if new_summary.is_empty() {
683 new_summary = ThreadSummary::DEFAULT;
684 }
685
686 if current_summary != &new_summary {
687 self.summary = ThreadSummary::Ready(new_summary);
688 cx.emit(ThreadEvent::SummaryChanged);
689 }
690 }
691
692 pub fn completion_mode(&self) -> CompletionMode {
693 self.completion_mode
694 }
695
696 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
697 self.completion_mode = mode;
698 }
699
700 pub fn message(&self, id: MessageId) -> Option<&Message> {
701 let index = self
702 .messages
703 .binary_search_by(|message| message.id.cmp(&id))
704 .ok()?;
705
706 self.messages.get(index)
707 }
708
709 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
710 self.messages.iter()
711 }
712
713 pub fn is_generating(&self) -> bool {
714 !self.pending_completions.is_empty() || !self.all_tools_finished()
715 }
716
717 /// Indicates whether streaming of language model events is stale.
718 /// When `is_generating()` is false, this method returns `None`.
719 pub fn is_generation_stale(&self) -> Option<bool> {
720 const STALE_THRESHOLD: u128 = 250;
721
722 self.last_received_chunk_at
723 .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
724 }
725
726 fn received_chunk(&mut self) {
727 self.last_received_chunk_at = Some(Instant::now());
728 }
729
730 pub fn queue_state(&self) -> Option<QueueState> {
731 self.pending_completions
732 .first()
733 .map(|pending_completion| pending_completion.queue_state)
734 }
735
736 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
737 &self.tools
738 }
739
740 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
741 self.tool_use
742 .pending_tool_uses()
743 .into_iter()
744 .find(|tool_use| &tool_use.id == id)
745 }
746
747 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
748 self.tool_use
749 .pending_tool_uses()
750 .into_iter()
751 .filter(|tool_use| tool_use.status.needs_confirmation())
752 }
753
754 pub fn has_pending_tool_uses(&self) -> bool {
755 !self.tool_use.pending_tool_uses().is_empty()
756 }
757
758 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
759 self.checkpoints_by_message.get(&id).cloned()
760 }
761
762 pub fn restore_checkpoint(
763 &mut self,
764 checkpoint: ThreadCheckpoint,
765 cx: &mut Context<Self>,
766 ) -> Task<Result<()>> {
767 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
768 message_id: checkpoint.message_id,
769 });
770 cx.emit(ThreadEvent::CheckpointChanged);
771 cx.notify();
772
773 let git_store = self.project().read(cx).git_store().clone();
774 let restore = git_store.update(cx, |git_store, cx| {
775 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
776 });
777
778 cx.spawn(async move |this, cx| {
779 let result = restore.await;
780 this.update(cx, |this, cx| {
781 if let Err(err) = result.as_ref() {
782 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
783 message_id: checkpoint.message_id,
784 error: err.to_string(),
785 });
786 } else {
787 this.truncate(checkpoint.message_id, cx);
788 this.last_restore_checkpoint = None;
789 }
790 this.pending_checkpoint = None;
791 cx.emit(ThreadEvent::CheckpointChanged);
792 cx.notify();
793 })?;
794 result
795 })
796 }
797
798 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
799 let pending_checkpoint = if self.is_generating() {
800 return;
801 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
802 checkpoint
803 } else {
804 return;
805 };
806
807 self.finalize_checkpoint(pending_checkpoint, cx);
808 }
809
810 fn finalize_checkpoint(
811 &mut self,
812 pending_checkpoint: ThreadCheckpoint,
813 cx: &mut Context<Self>,
814 ) {
815 let git_store = self.project.read(cx).git_store().clone();
816 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
817 cx.spawn(async move |this, cx| match final_checkpoint.await {
818 Ok(final_checkpoint) => {
819 let equal = git_store
820 .update(cx, |store, cx| {
821 store.compare_checkpoints(
822 pending_checkpoint.git_checkpoint.clone(),
823 final_checkpoint.clone(),
824 cx,
825 )
826 })?
827 .await
828 .unwrap_or(false);
829
830 if !equal {
831 this.update(cx, |this, cx| {
832 this.insert_checkpoint(pending_checkpoint, cx)
833 })?;
834 }
835
836 Ok(())
837 }
838 Err(_) => this.update(cx, |this, cx| {
839 this.insert_checkpoint(pending_checkpoint, cx)
840 }),
841 })
842 .detach();
843 }
844
845 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
846 self.checkpoints_by_message
847 .insert(checkpoint.message_id, checkpoint);
848 cx.emit(ThreadEvent::CheckpointChanged);
849 cx.notify();
850 }
851
852 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
853 self.last_restore_checkpoint.as_ref()
854 }
855
856 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
857 let Some(message_ix) = self
858 .messages
859 .iter()
860 .rposition(|message| message.id == message_id)
861 else {
862 return;
863 };
864 for deleted_message in self.messages.drain(message_ix..) {
865 self.checkpoints_by_message.remove(&deleted_message.id);
866 }
867 cx.notify();
868 }
869
870 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
871 self.messages
872 .iter()
873 .find(|message| message.id == id)
874 .into_iter()
875 .flat_map(|message| message.loaded_context.contexts.iter())
876 }
877
878 pub fn is_turn_end(&self, ix: usize) -> bool {
879 if self.messages.is_empty() {
880 return false;
881 }
882
883 if !self.is_generating() && ix == self.messages.len() - 1 {
884 return true;
885 }
886
887 let Some(message) = self.messages.get(ix) else {
888 return false;
889 };
890
891 if message.role != Role::Assistant {
892 return false;
893 }
894
895 self.messages
896 .get(ix + 1)
897 .and_then(|message| {
898 self.message(message.id)
899 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
900 })
901 .unwrap_or(false)
902 }
903
904 pub fn tool_use_limit_reached(&self) -> bool {
905 self.tool_use_limit_reached
906 }
907
908 /// Returns whether all of the tool uses have finished running.
909 pub fn all_tools_finished(&self) -> bool {
910 // If the only pending tool uses left are the ones with errors, then
911 // that means that we've finished running all of the pending tools.
912 self.tool_use
913 .pending_tool_uses()
914 .iter()
915 .all(|pending_tool_use| pending_tool_use.status.is_error())
916 }
917
918 /// Returns whether any pending tool uses may perform edits
919 pub fn has_pending_edit_tool_uses(&self) -> bool {
920 self.tool_use
921 .pending_tool_uses()
922 .iter()
923 .filter(|pending_tool_use| !pending_tool_use.status.is_error())
924 .any(|pending_tool_use| pending_tool_use.may_perform_edits)
925 }
926
927 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
928 self.tool_use.tool_uses_for_message(id, cx)
929 }
930
931 pub fn tool_results_for_message(
932 &self,
933 assistant_message_id: MessageId,
934 ) -> Vec<&LanguageModelToolResult> {
935 self.tool_use.tool_results_for_message(assistant_message_id)
936 }
937
938 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
939 self.tool_use.tool_result(id)
940 }
941
942 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
943 match &self.tool_use.tool_result(id)?.content {
944 LanguageModelToolResultContent::Text(text) => Some(text),
945 LanguageModelToolResultContent::Image(_) => {
946 // TODO: We should display image
947 None
948 }
949 }
950 }
951
952 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
953 self.tool_use.tool_result_card(id).cloned()
954 }
955
956 /// Return tools that are both enabled and supported by the model
957 pub fn available_tools(
958 &self,
959 cx: &App,
960 model: Arc<dyn LanguageModel>,
961 ) -> Vec<LanguageModelRequestTool> {
962 if model.supports_tools() {
963 self.profile
964 .enabled_tools(cx)
965 .into_iter()
966 .filter_map(|(name, tool)| {
967 // Skip tools that cannot be supported
968 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
969 Some(LanguageModelRequestTool {
970 name: name.into(),
971 description: tool.description(),
972 input_schema,
973 })
974 })
975 .collect()
976 } else {
977 Vec::default()
978 }
979 }
980
981 pub fn insert_user_message(
982 &mut self,
983 text: impl Into<String>,
984 loaded_context: ContextLoadResult,
985 git_checkpoint: Option<GitStoreCheckpoint>,
986 creases: Vec<MessageCrease>,
987 cx: &mut Context<Self>,
988 ) -> MessageId {
989 if !loaded_context.referenced_buffers.is_empty() {
990 self.action_log.update(cx, |log, cx| {
991 for buffer in loaded_context.referenced_buffers {
992 log.buffer_read(buffer, cx);
993 }
994 });
995 }
996
997 let message_id = self.insert_message(
998 Role::User,
999 vec![MessageSegment::Text(text.into())],
1000 loaded_context.loaded_context,
1001 creases,
1002 false,
1003 cx,
1004 );
1005
1006 if let Some(git_checkpoint) = git_checkpoint {
1007 self.pending_checkpoint = Some(ThreadCheckpoint {
1008 message_id,
1009 git_checkpoint,
1010 });
1011 }
1012
1013 self.auto_capture_telemetry(cx);
1014
1015 message_id
1016 }
1017
1018 pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
1019 let id = self.insert_message(
1020 Role::User,
1021 vec![MessageSegment::Text("Continue where you left off".into())],
1022 LoadedContext::default(),
1023 vec![],
1024 true,
1025 cx,
1026 );
1027 self.pending_checkpoint = None;
1028
1029 id
1030 }
1031
1032 pub fn insert_assistant_message(
1033 &mut self,
1034 segments: Vec<MessageSegment>,
1035 cx: &mut Context<Self>,
1036 ) -> MessageId {
1037 self.insert_message(
1038 Role::Assistant,
1039 segments,
1040 LoadedContext::default(),
1041 Vec::new(),
1042 false,
1043 cx,
1044 )
1045 }
1046
1047 pub fn insert_message(
1048 &mut self,
1049 role: Role,
1050 segments: Vec<MessageSegment>,
1051 loaded_context: LoadedContext,
1052 creases: Vec<MessageCrease>,
1053 is_hidden: bool,
1054 cx: &mut Context<Self>,
1055 ) -> MessageId {
1056 let id = self.next_message_id.post_inc();
1057 self.messages.push(Message {
1058 id,
1059 role,
1060 segments,
1061 loaded_context,
1062 creases,
1063 is_hidden,
1064 ui_only: false,
1065 });
1066 self.touch_updated_at();
1067 cx.emit(ThreadEvent::MessageAdded(id));
1068 id
1069 }
1070
1071 pub fn edit_message(
1072 &mut self,
1073 id: MessageId,
1074 new_role: Role,
1075 new_segments: Vec<MessageSegment>,
1076 creases: Vec<MessageCrease>,
1077 loaded_context: Option<LoadedContext>,
1078 checkpoint: Option<GitStoreCheckpoint>,
1079 cx: &mut Context<Self>,
1080 ) -> bool {
1081 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1082 return false;
1083 };
1084 message.role = new_role;
1085 message.segments = new_segments;
1086 message.creases = creases;
1087 if let Some(context) = loaded_context {
1088 message.loaded_context = context;
1089 }
1090 if let Some(git_checkpoint) = checkpoint {
1091 self.checkpoints_by_message.insert(
1092 id,
1093 ThreadCheckpoint {
1094 message_id: id,
1095 git_checkpoint,
1096 },
1097 );
1098 }
1099 self.touch_updated_at();
1100 cx.emit(ThreadEvent::MessageEdited(id));
1101 true
1102 }
1103
1104 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1105 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1106 return false;
1107 };
1108 self.messages.remove(index);
1109 self.touch_updated_at();
1110 cx.emit(ThreadEvent::MessageDeleted(id));
1111 true
1112 }
1113
1114 /// Returns the representation of this [`Thread`] in a textual form.
1115 ///
1116 /// This is the representation we use when attaching a thread as context to another thread.
1117 pub fn text(&self) -> String {
1118 let mut text = String::new();
1119
1120 for message in &self.messages {
1121 text.push_str(match message.role {
1122 language_model::Role::User => "User:",
1123 language_model::Role::Assistant => "Agent:",
1124 language_model::Role::System => "System:",
1125 });
1126 text.push('\n');
1127
1128 for segment in &message.segments {
1129 match segment {
1130 MessageSegment::Text(content) => text.push_str(content),
1131 MessageSegment::Thinking { text: content, .. } => {
1132 text.push_str(&format!("<think>{}</think>", content))
1133 }
1134 MessageSegment::RedactedThinking(_) => {}
1135 }
1136 }
1137 text.push('\n');
1138 }
1139
1140 text
1141 }
1142
1143 /// Serializes this thread into a format for storage or telemetry.
1144 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
1145 let initial_project_snapshot = self.initial_project_snapshot.clone();
1146 cx.spawn(async move |this, cx| {
1147 let initial_project_snapshot = initial_project_snapshot.await;
1148 this.read_with(cx, |this, cx| SerializedThread {
1149 version: SerializedThread::VERSION.to_string(),
1150 summary: this.summary().or_default(),
1151 updated_at: this.updated_at(),
1152 messages: this
1153 .messages()
1154 .filter(|message| !message.ui_only)
1155 .map(|message| SerializedMessage {
1156 id: message.id,
1157 role: message.role,
1158 segments: message
1159 .segments
1160 .iter()
1161 .map(|segment| match segment {
1162 MessageSegment::Text(text) => {
1163 SerializedMessageSegment::Text { text: text.clone() }
1164 }
1165 MessageSegment::Thinking { text, signature } => {
1166 SerializedMessageSegment::Thinking {
1167 text: text.clone(),
1168 signature: signature.clone(),
1169 }
1170 }
1171 MessageSegment::RedactedThinking(data) => {
1172 SerializedMessageSegment::RedactedThinking {
1173 data: data.clone(),
1174 }
1175 }
1176 })
1177 .collect(),
1178 tool_uses: this
1179 .tool_uses_for_message(message.id, cx)
1180 .into_iter()
1181 .map(|tool_use| SerializedToolUse {
1182 id: tool_use.id,
1183 name: tool_use.name,
1184 input: tool_use.input,
1185 })
1186 .collect(),
1187 tool_results: this
1188 .tool_results_for_message(message.id)
1189 .into_iter()
1190 .map(|tool_result| SerializedToolResult {
1191 tool_use_id: tool_result.tool_use_id.clone(),
1192 is_error: tool_result.is_error,
1193 content: tool_result.content.clone(),
1194 output: tool_result.output.clone(),
1195 })
1196 .collect(),
1197 context: message.loaded_context.text.clone(),
1198 creases: message
1199 .creases
1200 .iter()
1201 .map(|crease| SerializedCrease {
1202 start: crease.range.start,
1203 end: crease.range.end,
1204 icon_path: crease.icon_path.clone(),
1205 label: crease.label.clone(),
1206 })
1207 .collect(),
1208 is_hidden: message.is_hidden,
1209 })
1210 .collect(),
1211 initial_project_snapshot,
1212 cumulative_token_usage: this.cumulative_token_usage,
1213 request_token_usage: this.request_token_usage.clone(),
1214 detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
1215 exceeded_window_error: this.exceeded_window_error.clone(),
1216 model: this
1217 .configured_model
1218 .as_ref()
1219 .map(|model| SerializedLanguageModel {
1220 provider: model.provider.id().0.to_string(),
1221 model: model.model.id().0.to_string(),
1222 }),
1223 completion_mode: Some(this.completion_mode),
1224 tool_use_limit_reached: this.tool_use_limit_reached,
1225 profile: Some(this.profile.id().clone()),
1226 })
1227 })
1228 }
1229
1230 pub fn remaining_turns(&self) -> u32 {
1231 self.remaining_turns
1232 }
1233
1234 pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
1235 self.remaining_turns = remaining_turns;
1236 }
1237
1238 pub fn send_to_model(
1239 &mut self,
1240 model: Arc<dyn LanguageModel>,
1241 intent: CompletionIntent,
1242 window: Option<AnyWindowHandle>,
1243 cx: &mut Context<Self>,
1244 ) {
1245 if self.remaining_turns == 0 {
1246 return;
1247 }
1248
1249 self.remaining_turns -= 1;
1250
1251 self.flush_notifications(model.clone(), intent, cx);
1252
1253 let request = self.to_completion_request(model.clone(), intent, cx);
1254
1255 self.stream_completion(request, model, intent, window, cx);
1256 }
1257
1258 pub fn used_tools_since_last_user_message(&self) -> bool {
1259 for message in self.messages.iter().rev() {
1260 if self.tool_use.message_has_tool_results(message.id) {
1261 return true;
1262 } else if message.role == Role::User {
1263 return false;
1264 }
1265 }
1266
1267 false
1268 }
1269
1270 pub fn to_completion_request(
1271 &self,
1272 model: Arc<dyn LanguageModel>,
1273 intent: CompletionIntent,
1274 cx: &mut Context<Self>,
1275 ) -> LanguageModelRequest {
1276 let mut request = LanguageModelRequest {
1277 thread_id: Some(self.id.to_string()),
1278 prompt_id: Some(self.last_prompt_id.to_string()),
1279 intent: Some(intent),
1280 mode: None,
1281 messages: vec![],
1282 tools: Vec::new(),
1283 tool_choice: None,
1284 stop: Vec::new(),
1285 temperature: AgentSettings::temperature_for_model(&model, cx),
1286 };
1287
1288 let available_tools = self.available_tools(cx, model.clone());
1289 let available_tool_names = available_tools
1290 .iter()
1291 .map(|tool| tool.name.clone())
1292 .collect();
1293
1294 let model_context = &ModelContext {
1295 available_tools: available_tool_names,
1296 };
1297
1298 if let Some(project_context) = self.project_context.borrow().as_ref() {
1299 match self
1300 .prompt_builder
1301 .generate_assistant_system_prompt(project_context, model_context)
1302 {
1303 Err(err) => {
1304 let message = format!("{err:?}").into();
1305 log::error!("{message}");
1306 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1307 header: "Error generating system prompt".into(),
1308 message,
1309 }));
1310 }
1311 Ok(system_prompt) => {
1312 request.messages.push(LanguageModelRequestMessage {
1313 role: Role::System,
1314 content: vec![MessageContent::Text(system_prompt)],
1315 cache: true,
1316 });
1317 }
1318 }
1319 } else {
1320 let message = "Context for system prompt unexpectedly not ready.".into();
1321 log::error!("{message}");
1322 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1323 header: "Error generating system prompt".into(),
1324 message,
1325 }));
1326 }
1327
1328 let mut message_ix_to_cache = None;
1329 for message in &self.messages {
1330 // ui_only messages are for the UI only, not for the model
1331 if message.ui_only {
1332 continue;
1333 }
1334
1335 let mut request_message = LanguageModelRequestMessage {
1336 role: message.role,
1337 content: Vec::new(),
1338 cache: false,
1339 };
1340
1341 message
1342 .loaded_context
1343 .add_to_request_message(&mut request_message);
1344
1345 for segment in &message.segments {
1346 match segment {
1347 MessageSegment::Text(text) => {
1348 let text = text.trim_end();
1349 if !text.is_empty() {
1350 request_message
1351 .content
1352 .push(MessageContent::Text(text.into()));
1353 }
1354 }
1355 MessageSegment::Thinking { text, signature } => {
1356 if !text.is_empty() {
1357 request_message.content.push(MessageContent::Thinking {
1358 text: text.into(),
1359 signature: signature.clone(),
1360 });
1361 }
1362 }
1363 MessageSegment::RedactedThinking(data) => {
1364 request_message
1365 .content
1366 .push(MessageContent::RedactedThinking(data.clone()));
1367 }
1368 };
1369 }
1370
1371 let mut cache_message = true;
1372 let mut tool_results_message = LanguageModelRequestMessage {
1373 role: Role::User,
1374 content: Vec::new(),
1375 cache: false,
1376 };
1377 for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
1378 if let Some(tool_result) = tool_result {
1379 request_message
1380 .content
1381 .push(MessageContent::ToolUse(tool_use.clone()));
1382 tool_results_message
1383 .content
1384 .push(MessageContent::ToolResult(LanguageModelToolResult {
1385 tool_use_id: tool_use.id.clone(),
1386 tool_name: tool_result.tool_name.clone(),
1387 is_error: tool_result.is_error,
1388 content: if tool_result.content.is_empty() {
1389 // Surprisingly, the API fails if we return an empty string here.
1390 // It thinks we are sending a tool use without a tool result.
1391 "<Tool returned an empty string>".into()
1392 } else {
1393 tool_result.content.clone()
1394 },
1395 output: None,
1396 }));
1397 } else {
1398 cache_message = false;
1399 log::debug!(
1400 "skipped tool use {:?} because it is still pending",
1401 tool_use
1402 );
1403 }
1404 }
1405
1406 if cache_message {
1407 message_ix_to_cache = Some(request.messages.len());
1408 }
1409 request.messages.push(request_message);
1410
1411 if !tool_results_message.content.is_empty() {
1412 if cache_message {
1413 message_ix_to_cache = Some(request.messages.len());
1414 }
1415 request.messages.push(tool_results_message);
1416 }
1417 }
1418
1419 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1420 if let Some(message_ix_to_cache) = message_ix_to_cache {
1421 request.messages[message_ix_to_cache].cache = true;
1422 }
1423
1424 request.tools = available_tools;
1425 request.mode = if model.supports_burn_mode() {
1426 Some(self.completion_mode.into())
1427 } else {
1428 Some(CompletionMode::Normal.into())
1429 };
1430
1431 request
1432 }
1433
1434 fn to_summarize_request(
1435 &self,
1436 model: &Arc<dyn LanguageModel>,
1437 intent: CompletionIntent,
1438 added_user_message: String,
1439 cx: &App,
1440 ) -> LanguageModelRequest {
1441 let mut request = LanguageModelRequest {
1442 thread_id: None,
1443 prompt_id: None,
1444 intent: Some(intent),
1445 mode: None,
1446 messages: vec![],
1447 tools: Vec::new(),
1448 tool_choice: None,
1449 stop: Vec::new(),
1450 temperature: AgentSettings::temperature_for_model(model, cx),
1451 };
1452
1453 for message in &self.messages {
1454 let mut request_message = LanguageModelRequestMessage {
1455 role: message.role,
1456 content: Vec::new(),
1457 cache: false,
1458 };
1459
1460 for segment in &message.segments {
1461 match segment {
1462 MessageSegment::Text(text) => request_message
1463 .content
1464 .push(MessageContent::Text(text.clone())),
1465 MessageSegment::Thinking { .. } => {}
1466 MessageSegment::RedactedThinking(_) => {}
1467 }
1468 }
1469
1470 if request_message.content.is_empty() {
1471 continue;
1472 }
1473
1474 request.messages.push(request_message);
1475 }
1476
1477 request.messages.push(LanguageModelRequestMessage {
1478 role: Role::User,
1479 content: vec![MessageContent::Text(added_user_message)],
1480 cache: false,
1481 });
1482
1483 request
1484 }
1485
1486 /// Insert auto-generated notifications (if any) to the thread
1487 fn flush_notifications(
1488 &mut self,
1489 model: Arc<dyn LanguageModel>,
1490 intent: CompletionIntent,
1491 cx: &mut Context<Self>,
1492 ) {
1493 match intent {
1494 CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
1495 if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
1496 cx.emit(ThreadEvent::ToolFinished {
1497 tool_use_id: pending_tool_use.id.clone(),
1498 pending_tool_use: Some(pending_tool_use),
1499 });
1500 }
1501 }
1502 CompletionIntent::ThreadSummarization
1503 | CompletionIntent::ThreadContextSummarization
1504 | CompletionIntent::CreateFile
1505 | CompletionIntent::EditFile
1506 | CompletionIntent::InlineAssist
1507 | CompletionIntent::TerminalInlineAssist
1508 | CompletionIntent::GenerateGitCommitMessage => {}
1509 };
1510 }
1511
1512 fn attach_tracked_files_state(
1513 &mut self,
1514 model: Arc<dyn LanguageModel>,
1515 cx: &mut App,
1516 ) -> Option<PendingToolUse> {
1517 let action_log = self.action_log.read(cx);
1518
1519 action_log.unnotified_stale_buffers(cx).next()?;
1520
1521 // Represent notification as a simulated `project_notifications` tool call
1522 let tool_name = Arc::from("project_notifications");
1523 let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
1524 debug_panic!("`project_notifications` tool not found");
1525 return None;
1526 };
1527
1528 if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
1529 return None;
1530 }
1531
1532 let input = serde_json::json!({});
1533 let request = Arc::new(LanguageModelRequest::default()); // unused
1534 let window = None;
1535 let tool_result = tool.run(
1536 input,
1537 request,
1538 self.project.clone(),
1539 self.action_log.clone(),
1540 model.clone(),
1541 window,
1542 cx,
1543 );
1544
1545 let tool_use_id =
1546 LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));
1547
1548 let tool_use = LanguageModelToolUse {
1549 id: tool_use_id.clone(),
1550 name: tool_name.clone(),
1551 raw_input: "{}".to_string(),
1552 input: serde_json::json!({}),
1553 is_input_complete: true,
1554 };
1555
1556 let tool_output = cx.background_executor().block(tool_result.output);
1557
1558 // Attach a project_notification tool call to the latest existing
1559 // Assistant message. We cannot create a new Assistant message
1560 // because thinking models require a `thinking` block that we
1561 // cannot mock. We cannot send a notification as a normal
1562 // (non-tool-use) User message because this distracts Agent
1563 // too much.
1564 let tool_message_id = self
1565 .messages
1566 .iter()
1567 .enumerate()
1568 .rfind(|(_, message)| message.role == Role::Assistant)
1569 .map(|(_, message)| message.id)?;
1570
1571 let tool_use_metadata = ToolUseMetadata {
1572 model: model.clone(),
1573 thread_id: self.id.clone(),
1574 prompt_id: self.last_prompt_id.clone(),
1575 };
1576
1577 self.tool_use
1578 .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);
1579
1580 let pending_tool_use = self.tool_use.insert_tool_output(
1581 tool_use_id.clone(),
1582 tool_name,
1583 tool_output,
1584 self.configured_model.as_ref(),
1585 );
1586
1587 pending_tool_use
1588 }
1589
1590 pub fn stream_completion(
1591 &mut self,
1592 request: LanguageModelRequest,
1593 model: Arc<dyn LanguageModel>,
1594 intent: CompletionIntent,
1595 window: Option<AnyWindowHandle>,
1596 cx: &mut Context<Self>,
1597 ) {
1598 self.tool_use_limit_reached = false;
1599
1600 let pending_completion_id = post_inc(&mut self.completion_count);
1601 let mut request_callback_parameters = if self.request_callback.is_some() {
1602 Some((request.clone(), Vec::new()))
1603 } else {
1604 None
1605 };
1606 let prompt_id = self.last_prompt_id.clone();
1607 let tool_use_metadata = ToolUseMetadata {
1608 model: model.clone(),
1609 thread_id: self.id.clone(),
1610 prompt_id: prompt_id.clone(),
1611 };
1612
1613 self.last_received_chunk_at = Some(Instant::now());
1614
1615 let task = cx.spawn(async move |thread, cx| {
1616 let stream_completion_future = model.stream_completion(request, &cx);
1617 let initial_token_usage =
1618 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1619 let stream_completion = async {
1620 let mut events = stream_completion_future.await?;
1621
1622 let mut stop_reason = StopReason::EndTurn;
1623 let mut current_token_usage = TokenUsage::default();
1624
1625 thread
1626 .update(cx, |_thread, cx| {
1627 cx.emit(ThreadEvent::NewRequest);
1628 })
1629 .ok();
1630
1631 let mut request_assistant_message_id = None;
1632
1633 while let Some(event) = events.next().await {
1634 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1635 response_events
1636 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1637 }
1638
1639 thread.update(cx, |thread, cx| {
1640 match event? {
1641 LanguageModelCompletionEvent::StartMessage { .. } => {
1642 request_assistant_message_id =
1643 Some(thread.insert_assistant_message(
1644 vec![MessageSegment::Text(String::new())],
1645 cx,
1646 ));
1647 }
1648 LanguageModelCompletionEvent::Stop(reason) => {
1649 stop_reason = reason;
1650 }
1651 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1652 thread.update_token_usage_at_last_message(token_usage);
1653 thread.cumulative_token_usage = thread.cumulative_token_usage
1654 + token_usage
1655 - current_token_usage;
1656 current_token_usage = token_usage;
1657 }
1658 LanguageModelCompletionEvent::Text(chunk) => {
1659 thread.received_chunk();
1660
1661 cx.emit(ThreadEvent::ReceivedTextChunk);
1662 if let Some(last_message) = thread.messages.last_mut() {
1663 if last_message.role == Role::Assistant
1664 && !thread.tool_use.has_tool_results(last_message.id)
1665 {
1666 last_message.push_text(&chunk);
1667 cx.emit(ThreadEvent::StreamedAssistantText(
1668 last_message.id,
1669 chunk,
1670 ));
1671 } else {
1672 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1673 // of a new Assistant response.
1674 //
1675 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1676 // will result in duplicating the text of the chunk in the rendered Markdown.
1677 request_assistant_message_id =
1678 Some(thread.insert_assistant_message(
1679 vec![MessageSegment::Text(chunk.to_string())],
1680 cx,
1681 ));
1682 };
1683 }
1684 }
1685 LanguageModelCompletionEvent::Thinking {
1686 text: chunk,
1687 signature,
1688 } => {
1689 thread.received_chunk();
1690
1691 if let Some(last_message) = thread.messages.last_mut() {
1692 if last_message.role == Role::Assistant
1693 && !thread.tool_use.has_tool_results(last_message.id)
1694 {
1695 last_message.push_thinking(&chunk, signature);
1696 cx.emit(ThreadEvent::StreamedAssistantThinking(
1697 last_message.id,
1698 chunk,
1699 ));
1700 } else {
1701 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1702 // of a new Assistant response.
1703 //
1704 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1705 // will result in duplicating the text of the chunk in the rendered Markdown.
1706 request_assistant_message_id =
1707 Some(thread.insert_assistant_message(
1708 vec![MessageSegment::Thinking {
1709 text: chunk.to_string(),
1710 signature,
1711 }],
1712 cx,
1713 ));
1714 };
1715 }
1716 }
1717 LanguageModelCompletionEvent::RedactedThinking { data } => {
1718 thread.received_chunk();
1719
1720 if let Some(last_message) = thread.messages.last_mut() {
1721 if last_message.role == Role::Assistant
1722 && !thread.tool_use.has_tool_results(last_message.id)
1723 {
1724 last_message.push_redacted_thinking(data);
1725 } else {
1726 request_assistant_message_id =
1727 Some(thread.insert_assistant_message(
1728 vec![MessageSegment::RedactedThinking(data)],
1729 cx,
1730 ));
1731 };
1732 }
1733 }
1734 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1735 let last_assistant_message_id = request_assistant_message_id
1736 .unwrap_or_else(|| {
1737 let new_assistant_message_id =
1738 thread.insert_assistant_message(vec![], cx);
1739 request_assistant_message_id =
1740 Some(new_assistant_message_id);
1741 new_assistant_message_id
1742 });
1743
1744 let tool_use_id = tool_use.id.clone();
1745 let streamed_input = if tool_use.is_input_complete {
1746 None
1747 } else {
1748 Some((&tool_use.input).clone())
1749 };
1750
1751 let ui_text = thread.tool_use.request_tool_use(
1752 last_assistant_message_id,
1753 tool_use,
1754 tool_use_metadata.clone(),
1755 cx,
1756 );
1757
1758 if let Some(input) = streamed_input {
1759 cx.emit(ThreadEvent::StreamedToolUse {
1760 tool_use_id,
1761 ui_text,
1762 input,
1763 });
1764 }
1765 }
1766 LanguageModelCompletionEvent::ToolUseJsonParseError {
1767 id,
1768 tool_name,
1769 raw_input: invalid_input_json,
1770 json_parse_error,
1771 } => {
1772 thread.receive_invalid_tool_json(
1773 id,
1774 tool_name,
1775 invalid_input_json,
1776 json_parse_error,
1777 window,
1778 cx,
1779 );
1780 }
1781 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1782 if let Some(completion) = thread
1783 .pending_completions
1784 .iter_mut()
1785 .find(|completion| completion.id == pending_completion_id)
1786 {
1787 match status_update {
1788 CompletionRequestStatus::Queued { position } => {
1789 completion.queue_state =
1790 QueueState::Queued { position };
1791 }
1792 CompletionRequestStatus::Started => {
1793 completion.queue_state = QueueState::Started;
1794 }
1795 CompletionRequestStatus::Failed {
1796 code,
1797 message,
1798 request_id: _,
1799 retry_after,
1800 } => {
1801 return Err(
1802 LanguageModelCompletionError::from_cloud_failure(
1803 model.upstream_provider_name(),
1804 code,
1805 message,
1806 retry_after.map(Duration::from_secs_f64),
1807 ),
1808 );
1809 }
1810 CompletionRequestStatus::UsageUpdated { amount, limit } => {
1811 thread.update_model_request_usage(
1812 amount as u32,
1813 limit,
1814 cx,
1815 );
1816 }
1817 CompletionRequestStatus::ToolUseLimitReached => {
1818 thread.tool_use_limit_reached = true;
1819 cx.emit(ThreadEvent::ToolUseLimitReached);
1820 }
1821 }
1822 }
1823 }
1824 }
1825
1826 thread.touch_updated_at();
1827 cx.emit(ThreadEvent::StreamedCompletion);
1828 cx.notify();
1829
1830 thread.auto_capture_telemetry(cx);
1831 Ok(())
1832 })??;
1833
1834 smol::future::yield_now().await;
1835 }
1836
1837 thread.update(cx, |thread, cx| {
1838 thread.last_received_chunk_at = None;
1839 thread
1840 .pending_completions
1841 .retain(|completion| completion.id != pending_completion_id);
1842
1843 // If there is a response without tool use, summarize the message. Otherwise,
1844 // allow two tool uses before summarizing.
1845 if matches!(thread.summary, ThreadSummary::Pending)
1846 && thread.messages.len() >= 2
1847 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1848 {
1849 thread.summarize(cx);
1850 }
1851 })?;
1852
1853 anyhow::Ok(stop_reason)
1854 };
1855
1856 let result = stream_completion.await;
1857 let mut retry_scheduled = false;
1858
1859 thread
1860 .update(cx, |thread, cx| {
1861 thread.finalize_pending_checkpoint(cx);
1862 match result.as_ref() {
1863 Ok(stop_reason) => {
1864 match stop_reason {
1865 StopReason::ToolUse => {
1866 let tool_uses =
1867 thread.use_pending_tools(window, model.clone(), cx);
1868 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1869 }
1870 StopReason::EndTurn | StopReason::MaxTokens => {
1871 thread.project.update(cx, |project, cx| {
1872 project.set_agent_location(None, cx);
1873 });
1874 }
1875 StopReason::Refusal => {
1876 thread.project.update(cx, |project, cx| {
1877 project.set_agent_location(None, cx);
1878 });
1879
1880 // Remove the turn that was refused.
1881 //
1882 // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1883 {
1884 let mut messages_to_remove = Vec::new();
1885
1886 for (ix, message) in
1887 thread.messages.iter().enumerate().rev()
1888 {
1889 messages_to_remove.push(message.id);
1890
1891 if message.role == Role::User {
1892 if ix == 0 {
1893 break;
1894 }
1895
1896 if let Some(prev_message) =
1897 thread.messages.get(ix - 1)
1898 {
1899 if prev_message.role == Role::Assistant {
1900 break;
1901 }
1902 }
1903 }
1904 }
1905
1906 for message_id in messages_to_remove {
1907 thread.delete_message(message_id, cx);
1908 }
1909 }
1910
1911 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1912 header: "Language model refusal".into(),
1913 message:
1914 "Model refused to generate content for safety reasons."
1915 .into(),
1916 }));
1917 }
1918 }
1919
1920 // We successfully completed, so cancel any remaining retries.
1921 thread.retry_state = None;
1922 }
1923 Err(error) => {
1924 thread.project.update(cx, |project, cx| {
1925 project.set_agent_location(None, cx);
1926 });
1927
1928 fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
1929 let error_message = error
1930 .chain()
1931 .map(|err| err.to_string())
1932 .collect::<Vec<_>>()
1933 .join("\n");
1934 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1935 header: "Error interacting with language model".into(),
1936 message: SharedString::from(error_message.clone()),
1937 }));
1938 }
1939
1940 if error.is::<PaymentRequiredError>() {
1941 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1942 } else if let Some(error) =
1943 error.downcast_ref::<ModelRequestLimitReachedError>()
1944 {
1945 cx.emit(ThreadEvent::ShowError(
1946 ThreadError::ModelRequestLimitReached { plan: error.plan },
1947 ));
1948 } else if let Some(completion_error) =
1949 error.downcast_ref::<LanguageModelCompletionError>()
1950 {
1951 use LanguageModelCompletionError::*;
1952 match &completion_error {
1953 PromptTooLarge { tokens, .. } => {
1954 let tokens = tokens.unwrap_or_else(|| {
1955 // We didn't get an exact token count from the API, so fall back on our estimate.
1956 thread
1957 .total_token_usage()
1958 .map(|usage| usage.total)
1959 .unwrap_or(0)
1960 // We know the context window was exceeded in practice, so if our estimate was
1961 // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
1962 .max(model.max_token_count().saturating_add(1))
1963 });
1964 thread.exceeded_window_error = Some(ExceededWindowError {
1965 model_id: model.id(),
1966 token_count: tokens,
1967 });
1968 cx.notify();
1969 }
1970 RateLimitExceeded {
1971 retry_after: Some(retry_after),
1972 ..
1973 }
1974 | ServerOverloaded {
1975 retry_after: Some(retry_after),
1976 ..
1977 } => {
1978 thread.handle_rate_limit_error(
1979 &completion_error,
1980 *retry_after,
1981 model.clone(),
1982 intent,
1983 window,
1984 cx,
1985 );
1986 retry_scheduled = true;
1987 }
1988 RateLimitExceeded { .. } | ServerOverloaded { .. } => {
1989 retry_scheduled = thread.handle_retryable_error(
1990 &completion_error,
1991 model.clone(),
1992 intent,
1993 window,
1994 cx,
1995 );
1996 if !retry_scheduled {
1997 emit_generic_error(error, cx);
1998 }
1999 }
2000 ApiInternalServerError { .. }
2001 | ApiReadResponseError { .. }
2002 | HttpSend { .. } => {
2003 retry_scheduled = thread.handle_retryable_error(
2004 &completion_error,
2005 model.clone(),
2006 intent,
2007 window,
2008 cx,
2009 );
2010 if !retry_scheduled {
2011 emit_generic_error(error, cx);
2012 }
2013 }
2014 NoApiKey { .. }
2015 | HttpResponseError { .. }
2016 | BadRequestFormat { .. }
2017 | AuthenticationError { .. }
2018 | PermissionError { .. }
2019 | ApiEndpointNotFound { .. }
2020 | SerializeRequest { .. }
2021 | BuildRequestBody { .. }
2022 | DeserializeResponse { .. }
2023 | Other { .. } => emit_generic_error(error, cx),
2024 }
2025 } else {
2026 emit_generic_error(error, cx);
2027 }
2028
2029 if !retry_scheduled {
2030 thread.cancel_last_completion(window, cx);
2031 }
2032 }
2033 }
2034
2035 if !retry_scheduled {
2036 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
2037 }
2038
2039 if let Some((request_callback, (request, response_events))) = thread
2040 .request_callback
2041 .as_mut()
2042 .zip(request_callback_parameters.as_ref())
2043 {
2044 request_callback(request, response_events);
2045 }
2046
2047 thread.auto_capture_telemetry(cx);
2048
2049 if let Ok(initial_usage) = initial_token_usage {
2050 let usage = thread.cumulative_token_usage - initial_usage;
2051
2052 telemetry::event!(
2053 "Assistant Thread Completion",
2054 thread_id = thread.id().to_string(),
2055 prompt_id = prompt_id,
2056 model = model.telemetry_id(),
2057 model_provider = model.provider_id().to_string(),
2058 input_tokens = usage.input_tokens,
2059 output_tokens = usage.output_tokens,
2060 cache_creation_input_tokens = usage.cache_creation_input_tokens,
2061 cache_read_input_tokens = usage.cache_read_input_tokens,
2062 );
2063 }
2064 })
2065 .ok();
2066 });
2067
2068 self.pending_completions.push(PendingCompletion {
2069 id: pending_completion_id,
2070 queue_state: QueueState::Sending,
2071 _task: task,
2072 });
2073 }
2074
2075 pub fn summarize(&mut self, cx: &mut Context<Self>) {
2076 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2077 println!("No thread summary model");
2078 return;
2079 };
2080
2081 if !model.provider.is_authenticated(cx) {
2082 return;
2083 }
2084
2085 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2086
2087 let request = self.to_summarize_request(
2088 &model.model,
2089 CompletionIntent::ThreadSummarization,
2090 added_user_message.into(),
2091 cx,
2092 );
2093
2094 self.summary = ThreadSummary::Generating;
2095
2096 self.pending_summary = cx.spawn(async move |this, cx| {
2097 let result = async {
2098 let mut messages = model.model.stream_completion(request, &cx).await?;
2099
2100 let mut new_summary = String::new();
2101 while let Some(event) = messages.next().await {
2102 let Ok(event) = event else {
2103 continue;
2104 };
2105 let text = match event {
2106 LanguageModelCompletionEvent::Text(text) => text,
2107 LanguageModelCompletionEvent::StatusUpdate(
2108 CompletionRequestStatus::UsageUpdated { amount, limit },
2109 ) => {
2110 this.update(cx, |thread, cx| {
2111 thread.update_model_request_usage(amount as u32, limit, cx);
2112 })?;
2113 continue;
2114 }
2115 _ => continue,
2116 };
2117
2118 let mut lines = text.lines();
2119 new_summary.extend(lines.next());
2120
2121 // Stop if the LLM generated multiple lines.
2122 if lines.next().is_some() {
2123 break;
2124 }
2125 }
2126
2127 anyhow::Ok(new_summary)
2128 }
2129 .await;
2130
2131 this.update(cx, |this, cx| {
2132 match result {
2133 Ok(new_summary) => {
2134 if new_summary.is_empty() {
2135 this.summary = ThreadSummary::Error;
2136 } else {
2137 this.summary = ThreadSummary::Ready(new_summary.into());
2138 }
2139 }
2140 Err(err) => {
2141 this.summary = ThreadSummary::Error;
2142 log::error!("Failed to generate thread summary: {}", err);
2143 }
2144 }
2145 cx.emit(ThreadEvent::SummaryGenerated);
2146 })
2147 .log_err()?;
2148
2149 Some(())
2150 });
2151 }
2152
2153 fn handle_rate_limit_error(
2154 &mut self,
2155 error: &LanguageModelCompletionError,
2156 retry_after: Duration,
2157 model: Arc<dyn LanguageModel>,
2158 intent: CompletionIntent,
2159 window: Option<AnyWindowHandle>,
2160 cx: &mut Context<Self>,
2161 ) {
2162 // For rate limit errors, we only retry once with the specified duration
2163 let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
2164 log::warn!(
2165 "Retrying completion request in {} seconds: {error:?}",
2166 retry_after.as_secs(),
2167 );
2168
2169 // Add a UI-only message instead of a regular message
2170 let id = self.next_message_id.post_inc();
2171 self.messages.push(Message {
2172 id,
2173 role: Role::System,
2174 segments: vec![MessageSegment::Text(retry_message)],
2175 loaded_context: LoadedContext::default(),
2176 creases: Vec::new(),
2177 is_hidden: false,
2178 ui_only: true,
2179 });
2180 cx.emit(ThreadEvent::MessageAdded(id));
2181 // Schedule the retry
2182 let thread_handle = cx.entity().downgrade();
2183
2184 cx.spawn(async move |_thread, cx| {
2185 cx.background_executor().timer(retry_after).await;
2186
2187 thread_handle
2188 .update(cx, |thread, cx| {
2189 // Retry the completion
2190 thread.send_to_model(model, intent, window, cx);
2191 })
2192 .log_err();
2193 })
2194 .detach();
2195 }
2196
2197 fn handle_retryable_error(
2198 &mut self,
2199 error: &LanguageModelCompletionError,
2200 model: Arc<dyn LanguageModel>,
2201 intent: CompletionIntent,
2202 window: Option<AnyWindowHandle>,
2203 cx: &mut Context<Self>,
2204 ) -> bool {
2205 self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
2206 }
2207
2208 fn handle_retryable_error_with_delay(
2209 &mut self,
2210 error: &LanguageModelCompletionError,
2211 custom_delay: Option<Duration>,
2212 model: Arc<dyn LanguageModel>,
2213 intent: CompletionIntent,
2214 window: Option<AnyWindowHandle>,
2215 cx: &mut Context<Self>,
2216 ) -> bool {
2217 let retry_state = self.retry_state.get_or_insert(RetryState {
2218 attempt: 0,
2219 max_attempts: MAX_RETRY_ATTEMPTS,
2220 intent,
2221 });
2222
2223 retry_state.attempt += 1;
2224 let attempt = retry_state.attempt;
2225 let max_attempts = retry_state.max_attempts;
2226 let intent = retry_state.intent;
2227
2228 if attempt <= max_attempts {
2229 // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
2230 let delay = if let Some(custom_delay) = custom_delay {
2231 custom_delay
2232 } else {
2233 let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
2234 Duration::from_secs(delay_secs)
2235 };
2236
2237 // Add a transient message to inform the user
2238 let delay_secs = delay.as_secs();
2239 let retry_message = format!(
2240 "{error}. Retrying (attempt {attempt} of {max_attempts}) \
2241 in {delay_secs} seconds..."
2242 );
2243 log::warn!(
2244 "Retrying completion request (attempt {attempt} of {max_attempts}) \
2245 in {delay_secs} seconds: {error:?}",
2246 );
2247
2248 // Add a UI-only message instead of a regular message
2249 let id = self.next_message_id.post_inc();
2250 self.messages.push(Message {
2251 id,
2252 role: Role::System,
2253 segments: vec![MessageSegment::Text(retry_message)],
2254 loaded_context: LoadedContext::default(),
2255 creases: Vec::new(),
2256 is_hidden: false,
2257 ui_only: true,
2258 });
2259 cx.emit(ThreadEvent::MessageAdded(id));
2260
2261 // Schedule the retry
2262 let thread_handle = cx.entity().downgrade();
2263
2264 cx.spawn(async move |_thread, cx| {
2265 cx.background_executor().timer(delay).await;
2266
2267 thread_handle
2268 .update(cx, |thread, cx| {
2269 // Retry the completion
2270 thread.send_to_model(model, intent, window, cx);
2271 })
2272 .log_err();
2273 })
2274 .detach();
2275
2276 true
2277 } else {
2278 // Max retries exceeded
2279 self.retry_state = None;
2280
2281 let notification_text = if max_attempts == 1 {
2282 "Failed after retrying.".into()
2283 } else {
2284 format!("Failed after retrying {} times.", max_attempts).into()
2285 };
2286
2287 // Stop generating since we're giving up on retrying.
2288 self.pending_completions.clear();
2289
2290 cx.emit(ThreadEvent::RetriesFailed {
2291 message: notification_text,
2292 });
2293
2294 false
2295 }
2296 }
2297
2298 pub fn start_generating_detailed_summary_if_needed(
2299 &mut self,
2300 thread_store: WeakEntity<ThreadStore>,
2301 cx: &mut Context<Self>,
2302 ) {
2303 let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
2304 return;
2305 };
2306
2307 match &*self.detailed_summary_rx.borrow() {
2308 DetailedSummaryState::Generating { message_id, .. }
2309 | DetailedSummaryState::Generated { message_id, .. }
2310 if *message_id == last_message_id =>
2311 {
2312 // Already up-to-date
2313 return;
2314 }
2315 _ => {}
2316 }
2317
2318 let Some(ConfiguredModel { model, provider }) =
2319 LanguageModelRegistry::read_global(cx).thread_summary_model()
2320 else {
2321 return;
2322 };
2323
2324 if !provider.is_authenticated(cx) {
2325 return;
2326 }
2327
2328 let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");
2329
2330 let request = self.to_summarize_request(
2331 &model,
2332 CompletionIntent::ThreadContextSummarization,
2333 added_user_message.into(),
2334 cx,
2335 );
2336
2337 *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
2338 message_id: last_message_id,
2339 };
2340
2341 // Replace the detailed summarization task if there is one, cancelling it. It would probably
2342 // be better to allow the old task to complete, but this would require logic for choosing
2343 // which result to prefer (the old task could complete after the new one, resulting in a
2344 // stale summary).
2345 self.detailed_summary_task = cx.spawn(async move |thread, cx| {
2346 let stream = model.stream_completion_text(request, &cx);
2347 let Some(mut messages) = stream.await.log_err() else {
2348 thread
2349 .update(cx, |thread, _cx| {
2350 *thread.detailed_summary_tx.borrow_mut() =
2351 DetailedSummaryState::NotGenerated;
2352 })
2353 .ok()?;
2354 return None;
2355 };
2356
2357 let mut new_detailed_summary = String::new();
2358
2359 while let Some(chunk) = messages.stream.next().await {
2360 if let Some(chunk) = chunk.log_err() {
2361 new_detailed_summary.push_str(&chunk);
2362 }
2363 }
2364
2365 thread
2366 .update(cx, |thread, _cx| {
2367 *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
2368 text: new_detailed_summary.into(),
2369 message_id: last_message_id,
2370 };
2371 })
2372 .ok()?;
2373
2374 // Save thread so its summary can be reused later
2375 if let Some(thread) = thread.upgrade() {
2376 if let Ok(Ok(save_task)) = cx.update(|cx| {
2377 thread_store
2378 .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
2379 }) {
2380 save_task.await.log_err();
2381 }
2382 }
2383
2384 Some(())
2385 });
2386 }
2387
2388 pub async fn wait_for_detailed_summary_or_text(
2389 this: &Entity<Self>,
2390 cx: &mut AsyncApp,
2391 ) -> Option<SharedString> {
2392 let mut detailed_summary_rx = this
2393 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2394 .ok()?;
2395 loop {
2396 match detailed_summary_rx.recv().await? {
2397 DetailedSummaryState::Generating { .. } => {}
2398 DetailedSummaryState::NotGenerated => {
2399 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2400 }
2401 DetailedSummaryState::Generated { text, .. } => return Some(text),
2402 }
2403 }
2404 }
2405
2406 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2407 self.detailed_summary_rx
2408 .borrow()
2409 .text()
2410 .unwrap_or_else(|| self.text().into())
2411 }
2412
2413 pub fn is_generating_detailed_summary(&self) -> bool {
2414 matches!(
2415 &*self.detailed_summary_rx.borrow(),
2416 DetailedSummaryState::Generating { .. }
2417 )
2418 }
2419
2420 pub fn use_pending_tools(
2421 &mut self,
2422 window: Option<AnyWindowHandle>,
2423 model: Arc<dyn LanguageModel>,
2424 cx: &mut Context<Self>,
2425 ) -> Vec<PendingToolUse> {
2426 self.auto_capture_telemetry(cx);
2427 let request =
2428 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2429 let pending_tool_uses = self
2430 .tool_use
2431 .pending_tool_uses()
2432 .into_iter()
2433 .filter(|tool_use| tool_use.status.is_idle())
2434 .cloned()
2435 .collect::<Vec<_>>();
2436
2437 for tool_use in pending_tool_uses.iter() {
2438 self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2439 }
2440
2441 pending_tool_uses
2442 }
2443
2444 fn use_pending_tool(
2445 &mut self,
2446 tool_use: PendingToolUse,
2447 request: Arc<LanguageModelRequest>,
2448 model: Arc<dyn LanguageModel>,
2449 window: Option<AnyWindowHandle>,
2450 cx: &mut Context<Self>,
2451 ) {
2452 let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2453 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2454 };
2455
2456 if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2457 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2458 }
2459
2460 if tool.needs_confirmation(&tool_use.input, cx)
2461 && !AgentSettings::get_global(cx).always_allow_tool_actions
2462 {
2463 self.tool_use.confirm_tool_use(
2464 tool_use.id,
2465 tool_use.ui_text,
2466 tool_use.input,
2467 request,
2468 tool,
2469 );
2470 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2471 } else {
2472 self.run_tool(
2473 tool_use.id,
2474 tool_use.ui_text,
2475 tool_use.input,
2476 request,
2477 tool,
2478 model,
2479 window,
2480 cx,
2481 );
2482 }
2483 }
2484
2485 pub fn handle_hallucinated_tool_use(
2486 &mut self,
2487 tool_use_id: LanguageModelToolUseId,
2488 hallucinated_tool_name: Arc<str>,
2489 window: Option<AnyWindowHandle>,
2490 cx: &mut Context<Thread>,
2491 ) {
2492 let available_tools = self.profile.enabled_tools(cx);
2493
2494 let tool_list = available_tools
2495 .iter()
2496 .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2497 .collect::<Vec<_>>()
2498 .join("\n");
2499
2500 let error_message = format!(
2501 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2502 hallucinated_tool_name, tool_list
2503 );
2504
2505 let pending_tool_use = self.tool_use.insert_tool_output(
2506 tool_use_id.clone(),
2507 hallucinated_tool_name,
2508 Err(anyhow!("Missing tool call: {error_message}")),
2509 self.configured_model.as_ref(),
2510 );
2511
2512 cx.emit(ThreadEvent::MissingToolUse {
2513 tool_use_id: tool_use_id.clone(),
2514 ui_text: error_message.into(),
2515 });
2516
2517 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2518 }
2519
2520 pub fn receive_invalid_tool_json(
2521 &mut self,
2522 tool_use_id: LanguageModelToolUseId,
2523 tool_name: Arc<str>,
2524 invalid_json: Arc<str>,
2525 error: String,
2526 window: Option<AnyWindowHandle>,
2527 cx: &mut Context<Thread>,
2528 ) {
2529 log::error!("The model returned invalid input JSON: {invalid_json}");
2530
2531 let pending_tool_use = self.tool_use.insert_tool_output(
2532 tool_use_id.clone(),
2533 tool_name,
2534 Err(anyhow!("Error parsing input JSON: {error}")),
2535 self.configured_model.as_ref(),
2536 );
2537 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2538 pending_tool_use.ui_text.clone()
2539 } else {
2540 log::error!(
2541 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2542 );
2543 format!("Unknown tool {}", tool_use_id).into()
2544 };
2545
2546 cx.emit(ThreadEvent::InvalidToolInput {
2547 tool_use_id: tool_use_id.clone(),
2548 ui_text,
2549 invalid_input_json: invalid_json,
2550 });
2551
2552 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2553 }
2554
2555 pub fn run_tool(
2556 &mut self,
2557 tool_use_id: LanguageModelToolUseId,
2558 ui_text: impl Into<SharedString>,
2559 input: serde_json::Value,
2560 request: Arc<LanguageModelRequest>,
2561 tool: Arc<dyn Tool>,
2562 model: Arc<dyn LanguageModel>,
2563 window: Option<AnyWindowHandle>,
2564 cx: &mut Context<Thread>,
2565 ) {
2566 let task =
2567 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2568 self.tool_use
2569 .run_pending_tool(tool_use_id, ui_text.into(), task);
2570 }
2571
2572 fn spawn_tool_use(
2573 &mut self,
2574 tool_use_id: LanguageModelToolUseId,
2575 request: Arc<LanguageModelRequest>,
2576 input: serde_json::Value,
2577 tool: Arc<dyn Tool>,
2578 model: Arc<dyn LanguageModel>,
2579 window: Option<AnyWindowHandle>,
2580 cx: &mut Context<Thread>,
2581 ) -> Task<()> {
2582 let tool_name: Arc<str> = tool.name().into();
2583
2584 let tool_result = tool.run(
2585 input,
2586 request,
2587 self.project.clone(),
2588 self.action_log.clone(),
2589 model,
2590 window,
2591 cx,
2592 );
2593
2594 // Store the card separately if it exists
2595 if let Some(card) = tool_result.card.clone() {
2596 self.tool_use
2597 .insert_tool_result_card(tool_use_id.clone(), card);
2598 }
2599
2600 cx.spawn({
2601 async move |thread: WeakEntity<Thread>, cx| {
2602 let output = tool_result.output.await;
2603
2604 thread
2605 .update(cx, |thread, cx| {
2606 let pending_tool_use = thread.tool_use.insert_tool_output(
2607 tool_use_id.clone(),
2608 tool_name,
2609 output,
2610 thread.configured_model.as_ref(),
2611 );
2612 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2613 })
2614 .ok();
2615 }
2616 })
2617 }
2618
2619 fn tool_finished(
2620 &mut self,
2621 tool_use_id: LanguageModelToolUseId,
2622 pending_tool_use: Option<PendingToolUse>,
2623 canceled: bool,
2624 window: Option<AnyWindowHandle>,
2625 cx: &mut Context<Self>,
2626 ) {
2627 if self.all_tools_finished() {
2628 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2629 if !canceled {
2630 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2631 }
2632 self.auto_capture_telemetry(cx);
2633 }
2634 }
2635
2636 cx.emit(ThreadEvent::ToolFinished {
2637 tool_use_id,
2638 pending_tool_use,
2639 });
2640 }
2641
2642 /// Cancels the last pending completion, if there are any pending.
2643 ///
2644 /// Returns whether a completion was canceled.
2645 pub fn cancel_last_completion(
2646 &mut self,
2647 window: Option<AnyWindowHandle>,
2648 cx: &mut Context<Self>,
2649 ) -> bool {
2650 let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2651
2652 self.retry_state = None;
2653
2654 for pending_tool_use in self.tool_use.cancel_pending() {
2655 canceled = true;
2656 self.tool_finished(
2657 pending_tool_use.id.clone(),
2658 Some(pending_tool_use),
2659 true,
2660 window,
2661 cx,
2662 );
2663 }
2664
2665 if canceled {
2666 cx.emit(ThreadEvent::CompletionCanceled);
2667
2668 // When canceled, we always want to insert the checkpoint.
2669 // (We skip over finalize_pending_checkpoint, because it
2670 // would conclude we didn't have anything to insert here.)
2671 if let Some(checkpoint) = self.pending_checkpoint.take() {
2672 self.insert_checkpoint(checkpoint, cx);
2673 }
2674 } else {
2675 self.finalize_pending_checkpoint(cx);
2676 }
2677
2678 canceled
2679 }
2680
2681 /// Signals that any in-progress editing should be canceled.
2682 ///
2683 /// This method is used to notify listeners (like ActiveThread) that
2684 /// they should cancel any editing operations.
2685 pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
2686 cx.emit(ThreadEvent::CancelEditing);
2687 }
2688
2689 pub fn feedback(&self) -> Option<ThreadFeedback> {
2690 self.feedback
2691 }
2692
2693 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
2694 self.message_feedback.get(&message_id).copied()
2695 }
2696
2697 pub fn report_message_feedback(
2698 &mut self,
2699 message_id: MessageId,
2700 feedback: ThreadFeedback,
2701 cx: &mut Context<Self>,
2702 ) -> Task<Result<()>> {
2703 if self.message_feedback.get(&message_id) == Some(&feedback) {
2704 return Task::ready(Ok(()));
2705 }
2706
2707 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2708 let serialized_thread = self.serialize(cx);
2709 let thread_id = self.id().clone();
2710 let client = self.project.read(cx).client();
2711
2712 let enabled_tool_names: Vec<String> = self
2713 .profile
2714 .enabled_tools(cx)
2715 .iter()
2716 .map(|(name, _)| name.clone().into())
2717 .collect();
2718
2719 self.message_feedback.insert(message_id, feedback);
2720
2721 cx.notify();
2722
2723 let message_content = self
2724 .message(message_id)
2725 .map(|msg| msg.to_string())
2726 .unwrap_or_default();
2727
2728 cx.background_spawn(async move {
2729 let final_project_snapshot = final_project_snapshot.await;
2730 let serialized_thread = serialized_thread.await?;
2731 let thread_data =
2732 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2733
2734 let rating = match feedback {
2735 ThreadFeedback::Positive => "positive",
2736 ThreadFeedback::Negative => "negative",
2737 };
2738 telemetry::event!(
2739 "Assistant Thread Rated",
2740 rating,
2741 thread_id,
2742 enabled_tool_names,
2743 message_id = message_id.0,
2744 message_content,
2745 thread_data,
2746 final_project_snapshot
2747 );
2748 client.telemetry().flush_events().await;
2749
2750 Ok(())
2751 })
2752 }
2753
2754 pub fn report_feedback(
2755 &mut self,
2756 feedback: ThreadFeedback,
2757 cx: &mut Context<Self>,
2758 ) -> Task<Result<()>> {
2759 let last_assistant_message_id = self
2760 .messages
2761 .iter()
2762 .rev()
2763 .find(|msg| msg.role == Role::Assistant)
2764 .map(|msg| msg.id);
2765
2766 if let Some(message_id) = last_assistant_message_id {
2767 self.report_message_feedback(message_id, feedback, cx)
2768 } else {
2769 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2770 let serialized_thread = self.serialize(cx);
2771 let thread_id = self.id().clone();
2772 let client = self.project.read(cx).client();
2773 self.feedback = Some(feedback);
2774 cx.notify();
2775
2776 cx.background_spawn(async move {
2777 let final_project_snapshot = final_project_snapshot.await;
2778 let serialized_thread = serialized_thread.await?;
2779 let thread_data = serde_json::to_value(serialized_thread)
2780 .unwrap_or_else(|_| serde_json::Value::Null);
2781
2782 let rating = match feedback {
2783 ThreadFeedback::Positive => "positive",
2784 ThreadFeedback::Negative => "negative",
2785 };
2786 telemetry::event!(
2787 "Assistant Thread Rated",
2788 rating,
2789 thread_id,
2790 thread_data,
2791 final_project_snapshot
2792 );
2793 client.telemetry().flush_events().await;
2794
2795 Ok(())
2796 })
2797 }
2798 }
2799
2800 /// Create a snapshot of the current project state including git information and unsaved buffers.
2801 fn project_snapshot(
2802 project: Entity<Project>,
2803 cx: &mut Context<Self>,
2804 ) -> Task<Arc<ProjectSnapshot>> {
2805 let git_store = project.read(cx).git_store().clone();
2806 let worktree_snapshots: Vec<_> = project
2807 .read(cx)
2808 .visible_worktrees(cx)
2809 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2810 .collect();
2811
2812 cx.spawn(async move |_, cx| {
2813 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2814
2815 let mut unsaved_buffers = Vec::new();
2816 cx.update(|app_cx| {
2817 let buffer_store = project.read(app_cx).buffer_store();
2818 for buffer_handle in buffer_store.read(app_cx).buffers() {
2819 let buffer = buffer_handle.read(app_cx);
2820 if buffer.is_dirty() {
2821 if let Some(file) = buffer.file() {
2822 let path = file.path().to_string_lossy().to_string();
2823 unsaved_buffers.push(path);
2824 }
2825 }
2826 }
2827 })
2828 .ok();
2829
2830 Arc::new(ProjectSnapshot {
2831 worktree_snapshots,
2832 unsaved_buffer_paths: unsaved_buffers,
2833 timestamp: Utc::now(),
2834 })
2835 })
2836 }
2837
2838 fn worktree_snapshot(
2839 worktree: Entity<project::Worktree>,
2840 git_store: Entity<GitStore>,
2841 cx: &App,
2842 ) -> Task<WorktreeSnapshot> {
2843 cx.spawn(async move |cx| {
2844 // Get worktree path and snapshot
2845 let worktree_info = cx.update(|app_cx| {
2846 let worktree = worktree.read(app_cx);
2847 let path = worktree.abs_path().to_string_lossy().to_string();
2848 let snapshot = worktree.snapshot();
2849 (path, snapshot)
2850 });
2851
2852 let Ok((worktree_path, _snapshot)) = worktree_info else {
2853 return WorktreeSnapshot {
2854 worktree_path: String::new(),
2855 git_state: None,
2856 };
2857 };
2858
2859 let git_state = git_store
2860 .update(cx, |git_store, cx| {
2861 git_store
2862 .repositories()
2863 .values()
2864 .find(|repo| {
2865 repo.read(cx)
2866 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2867 .is_some()
2868 })
2869 .cloned()
2870 })
2871 .ok()
2872 .flatten()
2873 .map(|repo| {
2874 repo.update(cx, |repo, _| {
2875 let current_branch =
2876 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2877 repo.send_job(None, |state, _| async move {
2878 let RepositoryState::Local { backend, .. } = state else {
2879 return GitState {
2880 remote_url: None,
2881 head_sha: None,
2882 current_branch,
2883 diff: None,
2884 };
2885 };
2886
2887 let remote_url = backend.remote_url("origin");
2888 let head_sha = backend.head_sha().await;
2889 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2890
2891 GitState {
2892 remote_url,
2893 head_sha,
2894 current_branch,
2895 diff,
2896 }
2897 })
2898 })
2899 });
2900
2901 let git_state = match git_state {
2902 Some(git_state) => match git_state.ok() {
2903 Some(git_state) => git_state.await.ok(),
2904 None => None,
2905 },
2906 None => None,
2907 };
2908
2909 WorktreeSnapshot {
2910 worktree_path,
2911 git_state,
2912 }
2913 })
2914 }
2915
2916 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2917 let mut markdown = Vec::new();
2918
2919 let summary = self.summary().or_default();
2920 writeln!(markdown, "# {summary}\n")?;
2921
2922 for message in self.messages() {
2923 writeln!(
2924 markdown,
2925 "## {role}\n",
2926 role = match message.role {
2927 Role::User => "User",
2928 Role::Assistant => "Agent",
2929 Role::System => "System",
2930 }
2931 )?;
2932
2933 if !message.loaded_context.text.is_empty() {
2934 writeln!(markdown, "{}", message.loaded_context.text)?;
2935 }
2936
2937 if !message.loaded_context.images.is_empty() {
2938 writeln!(
2939 markdown,
2940 "\n{} images attached as context.\n",
2941 message.loaded_context.images.len()
2942 )?;
2943 }
2944
2945 for segment in &message.segments {
2946 match segment {
2947 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2948 MessageSegment::Thinking { text, .. } => {
2949 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2950 }
2951 MessageSegment::RedactedThinking(_) => {}
2952 }
2953 }
2954
2955 for tool_use in self.tool_uses_for_message(message.id, cx) {
2956 writeln!(
2957 markdown,
2958 "**Use Tool: {} ({})**",
2959 tool_use.name, tool_use.id
2960 )?;
2961 writeln!(markdown, "```json")?;
2962 writeln!(
2963 markdown,
2964 "{}",
2965 serde_json::to_string_pretty(&tool_use.input)?
2966 )?;
2967 writeln!(markdown, "```")?;
2968 }
2969
2970 for tool_result in self.tool_results_for_message(message.id) {
2971 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2972 if tool_result.is_error {
2973 write!(markdown, " (Error)")?;
2974 }
2975
2976 writeln!(markdown, "**\n")?;
2977 match &tool_result.content {
2978 LanguageModelToolResultContent::Text(text) => {
2979 writeln!(markdown, "{text}")?;
2980 }
2981 LanguageModelToolResultContent::Image(image) => {
2982 writeln!(markdown, "", image.source)?;
2983 }
2984 }
2985
2986 if let Some(output) = tool_result.output.as_ref() {
2987 writeln!(
2988 markdown,
2989 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2990 serde_json::to_string_pretty(output)?
2991 )?;
2992 }
2993 }
2994 }
2995
2996 Ok(String::from_utf8_lossy(&markdown).to_string())
2997 }
2998
2999 pub fn keep_edits_in_range(
3000 &mut self,
3001 buffer: Entity<language::Buffer>,
3002 buffer_range: Range<language::Anchor>,
3003 cx: &mut Context<Self>,
3004 ) {
3005 self.action_log.update(cx, |action_log, cx| {
3006 action_log.keep_edits_in_range(buffer, buffer_range, cx)
3007 });
3008 }
3009
3010 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
3011 self.action_log
3012 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
3013 }
3014
3015 pub fn reject_edits_in_ranges(
3016 &mut self,
3017 buffer: Entity<language::Buffer>,
3018 buffer_ranges: Vec<Range<language::Anchor>>,
3019 cx: &mut Context<Self>,
3020 ) -> Task<Result<()>> {
3021 self.action_log.update(cx, |action_log, cx| {
3022 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
3023 })
3024 }
3025
3026 pub fn action_log(&self) -> &Entity<ActionLog> {
3027 &self.action_log
3028 }
3029
3030 pub fn project(&self) -> &Entity<Project> {
3031 &self.project
3032 }
3033
3034 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
3035 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
3036 return;
3037 }
3038
3039 let now = Instant::now();
3040 if let Some(last) = self.last_auto_capture_at {
3041 if now.duration_since(last).as_secs() < 10 {
3042 return;
3043 }
3044 }
3045
3046 self.last_auto_capture_at = Some(now);
3047
3048 let thread_id = self.id().clone();
3049 let github_login = self
3050 .project
3051 .read(cx)
3052 .user_store()
3053 .read(cx)
3054 .current_user()
3055 .map(|user| user.github_login.clone());
3056 let client = self.project.read(cx).client();
3057 let serialize_task = self.serialize(cx);
3058
3059 cx.background_executor()
3060 .spawn(async move {
3061 if let Ok(serialized_thread) = serialize_task.await {
3062 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
3063 telemetry::event!(
3064 "Agent Thread Auto-Captured",
3065 thread_id = thread_id.to_string(),
3066 thread_data = thread_data,
3067 auto_capture_reason = "tracked_user",
3068 github_login = github_login
3069 );
3070
3071 client.telemetry().flush_events().await;
3072 }
3073 }
3074 })
3075 .detach();
3076 }
3077
3078 pub fn cumulative_token_usage(&self) -> TokenUsage {
3079 self.cumulative_token_usage
3080 }
3081
3082 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3083 let Some(model) = self.configured_model.as_ref() else {
3084 return TotalTokenUsage::default();
3085 };
3086
3087 let max = model.model.max_token_count();
3088
3089 let index = self
3090 .messages
3091 .iter()
3092 .position(|msg| msg.id == message_id)
3093 .unwrap_or(0);
3094
3095 if index == 0 {
3096 return TotalTokenUsage { total: 0, max };
3097 }
3098
3099 let token_usage = &self
3100 .request_token_usage
3101 .get(index - 1)
3102 .cloned()
3103 .unwrap_or_default();
3104
3105 TotalTokenUsage {
3106 total: token_usage.total_tokens(),
3107 max,
3108 }
3109 }
3110
3111 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3112 let model = self.configured_model.as_ref()?;
3113
3114 let max = model.model.max_token_count();
3115
3116 if let Some(exceeded_error) = &self.exceeded_window_error {
3117 if model.model.id() == exceeded_error.model_id {
3118 return Some(TotalTokenUsage {
3119 total: exceeded_error.token_count,
3120 max,
3121 });
3122 }
3123 }
3124
3125 let total = self
3126 .token_usage_at_last_message()
3127 .unwrap_or_default()
3128 .total_tokens();
3129
3130 Some(TotalTokenUsage { total, max })
3131 }
3132
3133 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
3134 self.request_token_usage
3135 .get(self.messages.len().saturating_sub(1))
3136 .or_else(|| self.request_token_usage.last())
3137 .cloned()
3138 }
3139
3140 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3141 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3142 self.request_token_usage
3143 .resize(self.messages.len(), placeholder);
3144
3145 if let Some(last) = self.request_token_usage.last_mut() {
3146 *last = token_usage;
3147 }
3148 }
3149
3150 fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3151 self.project.update(cx, |project, cx| {
3152 project.user_store().update(cx, |user_store, cx| {
3153 user_store.update_model_request_usage(
3154 ModelRequestUsage(RequestUsage {
3155 amount: amount as i32,
3156 limit,
3157 }),
3158 cx,
3159 )
3160 })
3161 });
3162 }
3163
3164 pub fn deny_tool_use(
3165 &mut self,
3166 tool_use_id: LanguageModelToolUseId,
3167 tool_name: Arc<str>,
3168 window: Option<AnyWindowHandle>,
3169 cx: &mut Context<Self>,
3170 ) {
3171 let err = Err(anyhow::anyhow!(
3172 "Permission to run tool action denied by user"
3173 ));
3174
3175 self.tool_use.insert_tool_output(
3176 tool_use_id.clone(),
3177 tool_name,
3178 err,
3179 self.configured_model.as_ref(),
3180 );
3181 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3182 }
3183}
3184
3185#[derive(Debug, Clone, Error)]
3186pub enum ThreadError {
3187 #[error("Payment required")]
3188 PaymentRequired,
3189 #[error("Model request limit reached")]
3190 ModelRequestLimitReached { plan: Plan },
3191 #[error("Message {header}: {message}")]
3192 Message {
3193 header: SharedString,
3194 message: SharedString,
3195 },
3196}
3197
3198#[derive(Debug, Clone)]
3199pub enum ThreadEvent {
3200 ShowError(ThreadError),
3201 StreamedCompletion,
3202 ReceivedTextChunk,
3203 NewRequest,
3204 StreamedAssistantText(MessageId, String),
3205 StreamedAssistantThinking(MessageId, String),
3206 StreamedToolUse {
3207 tool_use_id: LanguageModelToolUseId,
3208 ui_text: Arc<str>,
3209 input: serde_json::Value,
3210 },
3211 MissingToolUse {
3212 tool_use_id: LanguageModelToolUseId,
3213 ui_text: Arc<str>,
3214 },
3215 InvalidToolInput {
3216 tool_use_id: LanguageModelToolUseId,
3217 ui_text: Arc<str>,
3218 invalid_input_json: Arc<str>,
3219 },
3220 Stopped(Result<StopReason, Arc<anyhow::Error>>),
3221 MessageAdded(MessageId),
3222 MessageEdited(MessageId),
3223 MessageDeleted(MessageId),
3224 SummaryGenerated,
3225 SummaryChanged,
3226 UsePendingTools {
3227 tool_uses: Vec<PendingToolUse>,
3228 },
3229 ToolFinished {
3230 #[allow(unused)]
3231 tool_use_id: LanguageModelToolUseId,
3232 /// The pending tool use that corresponds to this tool.
3233 pending_tool_use: Option<PendingToolUse>,
3234 },
3235 CheckpointChanged,
3236 ToolConfirmationNeeded,
3237 ToolUseLimitReached,
3238 CancelEditing,
3239 CompletionCanceled,
3240 ProfileChanged,
3241 RetriesFailed {
3242 message: SharedString,
3243 },
3244}
3245
3246impl EventEmitter<ThreadEvent> for Thread {}
3247
3248struct PendingCompletion {
3249 id: usize,
3250 queue_state: QueueState,
3251 _task: Task<()>,
3252}
3253
3254#[cfg(test)]
3255mod tests {
3256 use super::*;
3257 use crate::{
3258 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3259 };
3260
3261 // Test-specific constants
3262 const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3263 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3264 use assistant_tool::ToolRegistry;
3265 use assistant_tools;
3266 use futures::StreamExt;
3267 use futures::future::BoxFuture;
3268 use futures::stream::BoxStream;
3269 use gpui::TestAppContext;
3270 use http_client;
3271 use indoc::indoc;
3272 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3273 use language_model::{
3274 LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3275 LanguageModelProviderName, LanguageModelToolChoice,
3276 };
3277 use parking_lot::Mutex;
3278 use project::{FakeFs, Project};
3279 use prompt_store::PromptBuilder;
3280 use serde_json::json;
3281 use settings::{Settings, SettingsStore};
3282 use std::sync::Arc;
3283 use std::time::Duration;
3284 use theme::ThemeSettings;
3285 use util::path;
3286 use workspace::Workspace;
3287
3288 #[gpui::test]
3289 async fn test_message_with_context(cx: &mut TestAppContext) {
3290 init_test_settings(cx);
3291
3292 let project = create_test_project(
3293 cx,
3294 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3295 )
3296 .await;
3297
3298 let (_workspace, _thread_store, thread, context_store, model) =
3299 setup_test_environment(cx, project.clone()).await;
3300
3301 add_file_to_context(&project, &context_store, "test/code.rs", cx)
3302 .await
3303 .unwrap();
3304
3305 let context =
3306 context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
3307 let loaded_context = cx
3308 .update(|cx| load_context(vec![context], &project, &None, cx))
3309 .await;
3310
3311 // Insert user message with context
3312 let message_id = thread.update(cx, |thread, cx| {
3313 thread.insert_user_message(
3314 "Please explain this code",
3315 loaded_context,
3316 None,
3317 Vec::new(),
3318 cx,
3319 )
3320 });
3321
3322 // Check content and context in message object
3323 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3324
3325 // Use different path format strings based on platform for the test
3326 #[cfg(windows)]
3327 let path_part = r"test\code.rs";
3328 #[cfg(not(windows))]
3329 let path_part = "test/code.rs";
3330
3331 let expected_context = format!(
3332 r#"
3333<context>
3334The following items were attached by the user. They are up-to-date and don't need to be re-read.
3335
3336<files>
3337```rs {path_part}
3338fn main() {{
3339 println!("Hello, world!");
3340}}
3341```
3342</files>
3343</context>
3344"#
3345 );
3346
3347 assert_eq!(message.role, Role::User);
3348 assert_eq!(message.segments.len(), 1);
3349 assert_eq!(
3350 message.segments[0],
3351 MessageSegment::Text("Please explain this code".to_string())
3352 );
3353 assert_eq!(message.loaded_context.text, expected_context);
3354
3355 // Check message in request
3356 let request = thread.update(cx, |thread, cx| {
3357 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3358 });
3359
3360 assert_eq!(request.messages.len(), 2);
3361 let expected_full_message = format!("{}Please explain this code", expected_context);
3362 assert_eq!(request.messages[1].string_contents(), expected_full_message);
3363 }
3364
3365 #[gpui::test]
3366 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
3367 init_test_settings(cx);
3368
3369 let project = create_test_project(
3370 cx,
3371 json!({
3372 "file1.rs": "fn function1() {}\n",
3373 "file2.rs": "fn function2() {}\n",
3374 "file3.rs": "fn function3() {}\n",
3375 "file4.rs": "fn function4() {}\n",
3376 }),
3377 )
3378 .await;
3379
3380 let (_, _thread_store, thread, context_store, model) =
3381 setup_test_environment(cx, project.clone()).await;
3382
3383 // First message with context 1
3384 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
3385 .await
3386 .unwrap();
3387 let new_contexts = context_store.update(cx, |store, cx| {
3388 store.new_context_for_thread(thread.read(cx), None)
3389 });
3390 assert_eq!(new_contexts.len(), 1);
3391 let loaded_context = cx
3392 .update(|cx| load_context(new_contexts, &project, &None, cx))
3393 .await;
3394 let message1_id = thread.update(cx, |thread, cx| {
3395 thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
3396 });
3397
3398 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
3399 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
3400 .await
3401 .unwrap();
3402 let new_contexts = context_store.update(cx, |store, cx| {
3403 store.new_context_for_thread(thread.read(cx), None)
3404 });
3405 assert_eq!(new_contexts.len(), 1);
3406 let loaded_context = cx
3407 .update(|cx| load_context(new_contexts, &project, &None, cx))
3408 .await;
3409 let message2_id = thread.update(cx, |thread, cx| {
3410 thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
3411 });
3412
3413 // Third message with all three contexts (contexts 1 and 2 should be skipped)
3414 //
3415 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
3416 .await
3417 .unwrap();
3418 let new_contexts = context_store.update(cx, |store, cx| {
3419 store.new_context_for_thread(thread.read(cx), None)
3420 });
3421 assert_eq!(new_contexts.len(), 1);
3422 let loaded_context = cx
3423 .update(|cx| load_context(new_contexts, &project, &None, cx))
3424 .await;
3425 let message3_id = thread.update(cx, |thread, cx| {
3426 thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
3427 });
3428
3429 // Check what contexts are included in each message
3430 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
3431 (
3432 thread.message(message1_id).unwrap().clone(),
3433 thread.message(message2_id).unwrap().clone(),
3434 thread.message(message3_id).unwrap().clone(),
3435 )
3436 });
3437
3438 // First message should include context 1
3439 assert!(message1.loaded_context.text.contains("file1.rs"));
3440
3441 // Second message should include only context 2 (not 1)
3442 assert!(!message2.loaded_context.text.contains("file1.rs"));
3443 assert!(message2.loaded_context.text.contains("file2.rs"));
3444
3445 // Third message should include only context 3 (not 1 or 2)
3446 assert!(!message3.loaded_context.text.contains("file1.rs"));
3447 assert!(!message3.loaded_context.text.contains("file2.rs"));
3448 assert!(message3.loaded_context.text.contains("file3.rs"));
3449
3450 // Check entire request to make sure all contexts are properly included
3451 let request = thread.update(cx, |thread, cx| {
3452 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3453 });
3454
3455 // The request should contain all 3 messages
3456 assert_eq!(request.messages.len(), 4);
3457
3458 // Check that the contexts are properly formatted in each message
3459 assert!(request.messages[1].string_contents().contains("file1.rs"));
3460 assert!(!request.messages[1].string_contents().contains("file2.rs"));
3461 assert!(!request.messages[1].string_contents().contains("file3.rs"));
3462
3463 assert!(!request.messages[2].string_contents().contains("file1.rs"));
3464 assert!(request.messages[2].string_contents().contains("file2.rs"));
3465 assert!(!request.messages[2].string_contents().contains("file3.rs"));
3466
3467 assert!(!request.messages[3].string_contents().contains("file1.rs"));
3468 assert!(!request.messages[3].string_contents().contains("file2.rs"));
3469 assert!(request.messages[3].string_contents().contains("file3.rs"));
3470
3471 add_file_to_context(&project, &context_store, "test/file4.rs", cx)
3472 .await
3473 .unwrap();
3474 let new_contexts = context_store.update(cx, |store, cx| {
3475 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3476 });
3477 assert_eq!(new_contexts.len(), 3);
3478 let loaded_context = cx
3479 .update(|cx| load_context(new_contexts, &project, &None, cx))
3480 .await
3481 .loaded_context;
3482
3483 assert!(!loaded_context.text.contains("file1.rs"));
3484 assert!(loaded_context.text.contains("file2.rs"));
3485 assert!(loaded_context.text.contains("file3.rs"));
3486 assert!(loaded_context.text.contains("file4.rs"));
3487
3488 let new_contexts = context_store.update(cx, |store, cx| {
3489 // Remove file4.rs
3490 store.remove_context(&loaded_context.contexts[2].handle(), cx);
3491 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3492 });
3493 assert_eq!(new_contexts.len(), 2);
3494 let loaded_context = cx
3495 .update(|cx| load_context(new_contexts, &project, &None, cx))
3496 .await
3497 .loaded_context;
3498
3499 assert!(!loaded_context.text.contains("file1.rs"));
3500 assert!(loaded_context.text.contains("file2.rs"));
3501 assert!(loaded_context.text.contains("file3.rs"));
3502 assert!(!loaded_context.text.contains("file4.rs"));
3503
3504 let new_contexts = context_store.update(cx, |store, cx| {
3505 // Remove file3.rs
3506 store.remove_context(&loaded_context.contexts[1].handle(), cx);
3507 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3508 });
3509 assert_eq!(new_contexts.len(), 1);
3510 let loaded_context = cx
3511 .update(|cx| load_context(new_contexts, &project, &None, cx))
3512 .await
3513 .loaded_context;
3514
3515 assert!(!loaded_context.text.contains("file1.rs"));
3516 assert!(loaded_context.text.contains("file2.rs"));
3517 assert!(!loaded_context.text.contains("file3.rs"));
3518 assert!(!loaded_context.text.contains("file4.rs"));
3519 }
3520
3521 #[gpui::test]
3522 async fn test_message_without_files(cx: &mut TestAppContext) {
3523 init_test_settings(cx);
3524
3525 let project = create_test_project(
3526 cx,
3527 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3528 )
3529 .await;
3530
3531 let (_, _thread_store, thread, _context_store, model) =
3532 setup_test_environment(cx, project.clone()).await;
3533
3534 // Insert user message without any context (empty context vector)
3535 let message_id = thread.update(cx, |thread, cx| {
3536 thread.insert_user_message(
3537 "What is the best way to learn Rust?",
3538 ContextLoadResult::default(),
3539 None,
3540 Vec::new(),
3541 cx,
3542 )
3543 });
3544
3545 // Check content and context in message object
3546 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3547
3548 // Context should be empty when no files are included
3549 assert_eq!(message.role, Role::User);
3550 assert_eq!(message.segments.len(), 1);
3551 assert_eq!(
3552 message.segments[0],
3553 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3554 );
3555 assert_eq!(message.loaded_context.text, "");
3556
3557 // Check message in request
3558 let request = thread.update(cx, |thread, cx| {
3559 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3560 });
3561
3562 assert_eq!(request.messages.len(), 2);
3563 assert_eq!(
3564 request.messages[1].string_contents(),
3565 "What is the best way to learn Rust?"
3566 );
3567
3568 // Add second message, also without context
3569 let message2_id = thread.update(cx, |thread, cx| {
3570 thread.insert_user_message(
3571 "Are there any good books?",
3572 ContextLoadResult::default(),
3573 None,
3574 Vec::new(),
3575 cx,
3576 )
3577 });
3578
3579 let message2 =
3580 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3581 assert_eq!(message2.loaded_context.text, "");
3582
3583 // Check that both messages appear in the request
3584 let request = thread.update(cx, |thread, cx| {
3585 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3586 });
3587
3588 assert_eq!(request.messages.len(), 3);
3589 assert_eq!(
3590 request.messages[1].string_contents(),
3591 "What is the best way to learn Rust?"
3592 );
3593 assert_eq!(
3594 request.messages[2].string_contents(),
3595 "Are there any good books?"
3596 );
3597 }
3598
3599 #[gpui::test]
3600 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
3601 init_test_settings(cx);
3602
3603 let project = create_test_project(
3604 cx,
3605 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3606 )
3607 .await;
3608
3609 let (_workspace, _thread_store, thread, context_store, model) =
3610 setup_test_environment(cx, project.clone()).await;
3611
3612 // Add a buffer to the context. This will be a tracked buffer
3613 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
3614 .await
3615 .unwrap();
3616
3617 let context = context_store
3618 .read_with(cx, |store, _| store.context().next().cloned())
3619 .unwrap();
3620 let loaded_context = cx
3621 .update(|cx| load_context(vec![context], &project, &None, cx))
3622 .await;
3623
3624 // Insert user message and assistant response
3625 thread.update(cx, |thread, cx| {
3626 thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
3627 thread.insert_assistant_message(
3628 vec![MessageSegment::Text("This code prints 42.".into())],
3629 cx,
3630 );
3631 });
3632
3633 // We shouldn't have a stale buffer notification yet
3634 let notifications = thread.read_with(cx, |thread, _| {
3635 find_tool_uses(thread, "project_notifications")
3636 });
3637 assert!(
3638 notifications.is_empty(),
3639 "Should not have stale buffer notification before buffer is modified"
3640 );
3641
3642 // Modify the buffer
3643 buffer.update(cx, |buffer, cx| {
3644 buffer.edit(
3645 [(1..1, "\n println!(\"Added a new line\");\n")],
3646 None,
3647 cx,
3648 );
3649 });
3650
3651 // Insert another user message
3652 thread.update(cx, |thread, cx| {
3653 thread.insert_user_message(
3654 "What does the code do now?",
3655 ContextLoadResult::default(),
3656 None,
3657 Vec::new(),
3658 cx,
3659 )
3660 });
3661
3662 // Check for the stale buffer warning
3663 thread.update(cx, |thread, cx| {
3664 thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
3665 });
3666
3667 let notifications = thread.read_with(cx, |thread, _cx| {
3668 find_tool_uses(thread, "project_notifications")
3669 });
3670
3671 let [notification] = notifications.as_slice() else {
3672 panic!("Should have a `project_notifications` tool use");
3673 };
3674
3675 let Some(notification_content) = notification.content.to_str() else {
3676 panic!("`project_notifications` should return text");
3677 };
3678
3679 let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]
3680
3681 These files have changed since the last read:
3682 - code.rs
3683 "};
3684 assert_eq!(notification_content, expected_content);
3685
3686 // Insert another user message and flush notifications again
3687 thread.update(cx, |thread, cx| {
3688 thread.insert_user_message(
3689 "Can you tell me more?",
3690 ContextLoadResult::default(),
3691 None,
3692 Vec::new(),
3693 cx,
3694 )
3695 });
3696
3697 thread.update(cx, |thread, cx| {
3698 thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
3699 });
3700
3701 // There should be no new notifications (we already flushed one)
3702 let notifications = thread.read_with(cx, |thread, _cx| {
3703 find_tool_uses(thread, "project_notifications")
3704 });
3705
3706 assert_eq!(
3707 notifications.len(),
3708 1,
3709 "Should still have only one notification after second flush - no duplicates"
3710 );
3711 }
3712
3713 fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3714 thread
3715 .messages()
3716 .flat_map(|message| {
3717 thread
3718 .tool_results_for_message(message.id)
3719 .into_iter()
3720 .filter(|result| result.tool_name == tool_name.into())
3721 .cloned()
3722 .collect::<Vec<_>>()
3723 })
3724 .collect()
3725 }
3726
3727 #[gpui::test]
3728 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3729 init_test_settings(cx);
3730
3731 let project = create_test_project(
3732 cx,
3733 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3734 )
3735 .await;
3736
3737 let (_workspace, thread_store, thread, _context_store, _model) =
3738 setup_test_environment(cx, project.clone()).await;
3739
3740 // Check that we are starting with the default profile
3741 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3742 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3743 assert_eq!(
3744 profile,
3745 AgentProfile::new(AgentProfileId::default(), tool_set)
3746 );
3747 }
3748
3749 #[gpui::test]
3750 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3751 init_test_settings(cx);
3752
3753 let project = create_test_project(
3754 cx,
3755 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3756 )
3757 .await;
3758
3759 let (_workspace, thread_store, thread, _context_store, _model) =
3760 setup_test_environment(cx, project.clone()).await;
3761
3762 // Profile gets serialized with default values
3763 let serialized = thread
3764 .update(cx, |thread, cx| thread.serialize(cx))
3765 .await
3766 .unwrap();
3767
3768 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3769
3770 let deserialized = cx.update(|cx| {
3771 thread.update(cx, |thread, cx| {
3772 Thread::deserialize(
3773 thread.id.clone(),
3774 serialized,
3775 thread.project.clone(),
3776 thread.tools.clone(),
3777 thread.prompt_builder.clone(),
3778 thread.project_context.clone(),
3779 None,
3780 cx,
3781 )
3782 })
3783 });
3784 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3785
3786 assert_eq!(
3787 deserialized.profile,
3788 AgentProfile::new(AgentProfileId::default(), tool_set)
3789 );
3790 }
3791
3792 #[gpui::test]
3793 async fn test_temperature_setting(cx: &mut TestAppContext) {
3794 init_test_settings(cx);
3795
3796 let project = create_test_project(
3797 cx,
3798 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3799 )
3800 .await;
3801
3802 let (_workspace, _thread_store, thread, _context_store, model) =
3803 setup_test_environment(cx, project.clone()).await;
3804
3805 // Both model and provider
3806 cx.update(|cx| {
3807 AgentSettings::override_global(
3808 AgentSettings {
3809 model_parameters: vec![LanguageModelParameters {
3810 provider: Some(model.provider_id().0.to_string().into()),
3811 model: Some(model.id().0.clone()),
3812 temperature: Some(0.66),
3813 }],
3814 ..AgentSettings::get_global(cx).clone()
3815 },
3816 cx,
3817 );
3818 });
3819
3820 let request = thread.update(cx, |thread, cx| {
3821 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3822 });
3823 assert_eq!(request.temperature, Some(0.66));
3824
3825 // Only model
3826 cx.update(|cx| {
3827 AgentSettings::override_global(
3828 AgentSettings {
3829 model_parameters: vec![LanguageModelParameters {
3830 provider: None,
3831 model: Some(model.id().0.clone()),
3832 temperature: Some(0.66),
3833 }],
3834 ..AgentSettings::get_global(cx).clone()
3835 },
3836 cx,
3837 );
3838 });
3839
3840 let request = thread.update(cx, |thread, cx| {
3841 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3842 });
3843 assert_eq!(request.temperature, Some(0.66));
3844
3845 // Only provider
3846 cx.update(|cx| {
3847 AgentSettings::override_global(
3848 AgentSettings {
3849 model_parameters: vec![LanguageModelParameters {
3850 provider: Some(model.provider_id().0.to_string().into()),
3851 model: None,
3852 temperature: Some(0.66),
3853 }],
3854 ..AgentSettings::get_global(cx).clone()
3855 },
3856 cx,
3857 );
3858 });
3859
3860 let request = thread.update(cx, |thread, cx| {
3861 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3862 });
3863 assert_eq!(request.temperature, Some(0.66));
3864
3865 // Same model name, different provider
3866 cx.update(|cx| {
3867 AgentSettings::override_global(
3868 AgentSettings {
3869 model_parameters: vec![LanguageModelParameters {
3870 provider: Some("anthropic".into()),
3871 model: Some(model.id().0.clone()),
3872 temperature: Some(0.66),
3873 }],
3874 ..AgentSettings::get_global(cx).clone()
3875 },
3876 cx,
3877 );
3878 });
3879
3880 let request = thread.update(cx, |thread, cx| {
3881 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3882 });
3883 assert_eq!(request.temperature, None);
3884 }
3885
3886 #[gpui::test]
3887 async fn test_thread_summary(cx: &mut TestAppContext) {
3888 init_test_settings(cx);
3889
3890 let project = create_test_project(cx, json!({})).await;
3891
3892 let (_, _thread_store, thread, _context_store, model) =
3893 setup_test_environment(cx, project.clone()).await;
3894
3895 // Initial state should be pending
3896 thread.read_with(cx, |thread, _| {
3897 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3898 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3899 });
3900
3901 // Manually setting the summary should not be allowed in this state
3902 thread.update(cx, |thread, cx| {
3903 thread.set_summary("This should not work", cx);
3904 });
3905
3906 thread.read_with(cx, |thread, _| {
3907 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3908 });
3909
3910 // Send a message
3911 thread.update(cx, |thread, cx| {
3912 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3913 thread.send_to_model(
3914 model.clone(),
3915 CompletionIntent::ThreadSummarization,
3916 None,
3917 cx,
3918 );
3919 });
3920
3921 let fake_model = model.as_fake();
3922 simulate_successful_response(&fake_model, cx);
3923
3924 // Should start generating summary when there are >= 2 messages
3925 thread.read_with(cx, |thread, _| {
3926 assert_eq!(*thread.summary(), ThreadSummary::Generating);
3927 });
3928
3929 // Should not be able to set the summary while generating
3930 thread.update(cx, |thread, cx| {
3931 thread.set_summary("This should not work either", cx);
3932 });
3933
3934 thread.read_with(cx, |thread, _| {
3935 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3936 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3937 });
3938
3939 cx.run_until_parked();
3940 fake_model.stream_last_completion_response("Brief");
3941 fake_model.stream_last_completion_response(" Introduction");
3942 fake_model.end_last_completion_stream();
3943 cx.run_until_parked();
3944
3945 // Summary should be set
3946 thread.read_with(cx, |thread, _| {
3947 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3948 assert_eq!(thread.summary().or_default(), "Brief Introduction");
3949 });
3950
3951 // Now we should be able to set a summary
3952 thread.update(cx, |thread, cx| {
3953 thread.set_summary("Brief Intro", cx);
3954 });
3955
3956 thread.read_with(cx, |thread, _| {
3957 assert_eq!(thread.summary().or_default(), "Brief Intro");
3958 });
3959
3960 // Test setting an empty summary (should default to DEFAULT)
3961 thread.update(cx, |thread, cx| {
3962 thread.set_summary("", cx);
3963 });
3964
3965 thread.read_with(cx, |thread, _| {
3966 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3967 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3968 });
3969 }
3970
3971 #[gpui::test]
3972 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3973 init_test_settings(cx);
3974
3975 let project = create_test_project(cx, json!({})).await;
3976
3977 let (_, _thread_store, thread, _context_store, model) =
3978 setup_test_environment(cx, project.clone()).await;
3979
3980 test_summarize_error(&model, &thread, cx);
3981
3982 // Now we should be able to set a summary
3983 thread.update(cx, |thread, cx| {
3984 thread.set_summary("Brief Intro", cx);
3985 });
3986
3987 thread.read_with(cx, |thread, _| {
3988 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3989 assert_eq!(thread.summary().or_default(), "Brief Intro");
3990 });
3991 }
3992
3993 #[gpui::test]
3994 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3995 init_test_settings(cx);
3996
3997 let project = create_test_project(cx, json!({})).await;
3998
3999 let (_, _thread_store, thread, _context_store, model) =
4000 setup_test_environment(cx, project.clone()).await;
4001
4002 test_summarize_error(&model, &thread, cx);
4003
4004 // Sending another message should not trigger another summarize request
4005 thread.update(cx, |thread, cx| {
4006 thread.insert_user_message(
4007 "How are you?",
4008 ContextLoadResult::default(),
4009 None,
4010 vec![],
4011 cx,
4012 );
4013 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4014 });
4015
4016 let fake_model = model.as_fake();
4017 simulate_successful_response(&fake_model, cx);
4018
4019 thread.read_with(cx, |thread, _| {
4020 // State is still Error, not Generating
4021 assert!(matches!(thread.summary(), ThreadSummary::Error));
4022 });
4023
4024 // But the summarize request can be invoked manually
4025 thread.update(cx, |thread, cx| {
4026 thread.summarize(cx);
4027 });
4028
4029 thread.read_with(cx, |thread, _| {
4030 assert!(matches!(thread.summary(), ThreadSummary::Generating));
4031 });
4032
4033 cx.run_until_parked();
4034 fake_model.stream_last_completion_response("A successful summary");
4035 fake_model.end_last_completion_stream();
4036 cx.run_until_parked();
4037
4038 thread.read_with(cx, |thread, _| {
4039 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4040 assert_eq!(thread.summary().or_default(), "A successful summary");
4041 });
4042 }
4043
4044 // Helper to create a model that returns errors
4045 enum TestError {
4046 Overloaded,
4047 InternalServerError,
4048 }
4049
4050 struct ErrorInjector {
4051 inner: Arc<FakeLanguageModel>,
4052 error_type: TestError,
4053 }
4054
4055 impl ErrorInjector {
4056 fn new(error_type: TestError) -> Self {
4057 Self {
4058 inner: Arc::new(FakeLanguageModel::default()),
4059 error_type,
4060 }
4061 }
4062 }
4063
4064 impl LanguageModel for ErrorInjector {
4065 fn id(&self) -> LanguageModelId {
4066 self.inner.id()
4067 }
4068
4069 fn name(&self) -> LanguageModelName {
4070 self.inner.name()
4071 }
4072
4073 fn provider_id(&self) -> LanguageModelProviderId {
4074 self.inner.provider_id()
4075 }
4076
4077 fn provider_name(&self) -> LanguageModelProviderName {
4078 self.inner.provider_name()
4079 }
4080
4081 fn supports_tools(&self) -> bool {
4082 self.inner.supports_tools()
4083 }
4084
4085 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4086 self.inner.supports_tool_choice(choice)
4087 }
4088
4089 fn supports_images(&self) -> bool {
4090 self.inner.supports_images()
4091 }
4092
4093 fn telemetry_id(&self) -> String {
4094 self.inner.telemetry_id()
4095 }
4096
4097 fn max_token_count(&self) -> u64 {
4098 self.inner.max_token_count()
4099 }
4100
4101 fn count_tokens(
4102 &self,
4103 request: LanguageModelRequest,
4104 cx: &App,
4105 ) -> BoxFuture<'static, Result<u64>> {
4106 self.inner.count_tokens(request, cx)
4107 }
4108
4109 fn stream_completion(
4110 &self,
4111 _request: LanguageModelRequest,
4112 _cx: &AsyncApp,
4113 ) -> BoxFuture<
4114 'static,
4115 Result<
4116 BoxStream<
4117 'static,
4118 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4119 >,
4120 LanguageModelCompletionError,
4121 >,
4122 > {
4123 let error = match self.error_type {
4124 TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
4125 provider: self.provider_name(),
4126 retry_after: None,
4127 },
4128 TestError::InternalServerError => {
4129 LanguageModelCompletionError::ApiInternalServerError {
4130 provider: self.provider_name(),
4131 message: "I'm a teapot orbiting the sun".to_string(),
4132 }
4133 }
4134 };
4135 async move {
4136 let stream = futures::stream::once(async move { Err(error) });
4137 Ok(stream.boxed())
4138 }
4139 .boxed()
4140 }
4141
4142 fn as_fake(&self) -> &FakeLanguageModel {
4143 &self.inner
4144 }
4145 }
4146
4147 #[gpui::test]
4148 async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4149 init_test_settings(cx);
4150
4151 let project = create_test_project(cx, json!({})).await;
4152 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4153
4154 // Create model that returns overloaded error
4155 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4156
4157 // Insert a user message
4158 thread.update(cx, |thread, cx| {
4159 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4160 });
4161
4162 // Start completion
4163 thread.update(cx, |thread, cx| {
4164 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4165 });
4166
4167 cx.run_until_parked();
4168
4169 thread.read_with(cx, |thread, _| {
4170 assert!(thread.retry_state.is_some(), "Should have retry state");
4171 let retry_state = thread.retry_state.as_ref().unwrap();
4172 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4173 assert_eq!(
4174 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4175 "Should have default max attempts"
4176 );
4177 });
4178
4179 // Check that a retry message was added
4180 thread.read_with(cx, |thread, _| {
4181 let mut messages = thread.messages();
4182 assert!(
4183 messages.any(|msg| {
4184 msg.role == Role::System
4185 && msg.ui_only
4186 && msg.segments.iter().any(|seg| {
4187 if let MessageSegment::Text(text) = seg {
4188 text.contains("overloaded")
4189 && text
4190 .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4191 } else {
4192 false
4193 }
4194 })
4195 }),
4196 "Should have added a system retry message"
4197 );
4198 });
4199
4200 let retry_count = thread.update(cx, |thread, _| {
4201 thread
4202 .messages
4203 .iter()
4204 .filter(|m| {
4205 m.ui_only
4206 && m.segments.iter().any(|s| {
4207 if let MessageSegment::Text(text) = s {
4208 text.contains("Retrying") && text.contains("seconds")
4209 } else {
4210 false
4211 }
4212 })
4213 })
4214 .count()
4215 });
4216
4217 assert_eq!(retry_count, 1, "Should have one retry message");
4218 }
4219
4220 #[gpui::test]
4221 async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4222 init_test_settings(cx);
4223
4224 let project = create_test_project(cx, json!({})).await;
4225 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4226
4227 // Create model that returns internal server error
4228 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4229
4230 // Insert a user message
4231 thread.update(cx, |thread, cx| {
4232 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4233 });
4234
4235 // Start completion
4236 thread.update(cx, |thread, cx| {
4237 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4238 });
4239
4240 cx.run_until_parked();
4241
4242 // Check retry state on thread
4243 thread.read_with(cx, |thread, _| {
4244 assert!(thread.retry_state.is_some(), "Should have retry state");
4245 let retry_state = thread.retry_state.as_ref().unwrap();
4246 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4247 assert_eq!(
4248 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4249 "Should have correct max attempts"
4250 );
4251 });
4252
4253 // Check that a retry message was added with provider name
4254 thread.read_with(cx, |thread, _| {
4255 let mut messages = thread.messages();
4256 assert!(
4257 messages.any(|msg| {
4258 msg.role == Role::System
4259 && msg.ui_only
4260 && msg.segments.iter().any(|seg| {
4261 if let MessageSegment::Text(text) = seg {
4262 text.contains("internal")
4263 && text.contains("Fake")
4264 && text
4265 .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4266 } else {
4267 false
4268 }
4269 })
4270 }),
4271 "Should have added a system retry message with provider name"
4272 );
4273 });
4274
4275 // Count retry messages
4276 let retry_count = thread.update(cx, |thread, _| {
4277 thread
4278 .messages
4279 .iter()
4280 .filter(|m| {
4281 m.ui_only
4282 && m.segments.iter().any(|s| {
4283 if let MessageSegment::Text(text) = s {
4284 text.contains("Retrying") && text.contains("seconds")
4285 } else {
4286 false
4287 }
4288 })
4289 })
4290 .count()
4291 });
4292
4293 assert_eq!(retry_count, 1, "Should have one retry message");
4294 }
4295
4296 #[gpui::test]
4297 async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4298 init_test_settings(cx);
4299
4300 let project = create_test_project(cx, json!({})).await;
4301 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4302
4303 // Create model that returns overloaded error
4304 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4305
4306 // Insert a user message
4307 thread.update(cx, |thread, cx| {
4308 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4309 });
4310
4311 // Track retry events and completion count
4312 // Track completion events
4313 let completion_count = Arc::new(Mutex::new(0));
4314 let completion_count_clone = completion_count.clone();
4315
4316 let _subscription = thread.update(cx, |_, cx| {
4317 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4318 if let ThreadEvent::NewRequest = event {
4319 *completion_count_clone.lock() += 1;
4320 }
4321 })
4322 });
4323
4324 // First attempt
4325 thread.update(cx, |thread, cx| {
4326 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4327 });
4328 cx.run_until_parked();
4329
4330 // Should have scheduled first retry - count retry messages
4331 let retry_count = thread.update(cx, |thread, _| {
4332 thread
4333 .messages
4334 .iter()
4335 .filter(|m| {
4336 m.ui_only
4337 && m.segments.iter().any(|s| {
4338 if let MessageSegment::Text(text) = s {
4339 text.contains("Retrying") && text.contains("seconds")
4340 } else {
4341 false
4342 }
4343 })
4344 })
4345 .count()
4346 });
4347 assert_eq!(retry_count, 1, "Should have scheduled first retry");
4348
4349 // Check retry state
4350 thread.read_with(cx, |thread, _| {
4351 assert!(thread.retry_state.is_some(), "Should have retry state");
4352 let retry_state = thread.retry_state.as_ref().unwrap();
4353 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4354 });
4355
4356 // Advance clock for first retry
4357 cx.executor()
4358 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4359 cx.run_until_parked();
4360
4361 // Should have scheduled second retry - count retry messages
4362 let retry_count = thread.update(cx, |thread, _| {
4363 thread
4364 .messages
4365 .iter()
4366 .filter(|m| {
4367 m.ui_only
4368 && m.segments.iter().any(|s| {
4369 if let MessageSegment::Text(text) = s {
4370 text.contains("Retrying") && text.contains("seconds")
4371 } else {
4372 false
4373 }
4374 })
4375 })
4376 .count()
4377 });
4378 assert_eq!(retry_count, 2, "Should have scheduled second retry");
4379
4380 // Check retry state updated
4381 thread.read_with(cx, |thread, _| {
4382 assert!(thread.retry_state.is_some(), "Should have retry state");
4383 let retry_state = thread.retry_state.as_ref().unwrap();
4384 assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
4385 assert_eq!(
4386 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4387 "Should have correct max attempts"
4388 );
4389 });
4390
4391 // Advance clock for second retry (exponential backoff)
4392 cx.executor()
4393 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
4394 cx.run_until_parked();
4395
4396 // Should have scheduled third retry
4397 // Count all retry messages now
4398 let retry_count = thread.update(cx, |thread, _| {
4399 thread
4400 .messages
4401 .iter()
4402 .filter(|m| {
4403 m.ui_only
4404 && m.segments.iter().any(|s| {
4405 if let MessageSegment::Text(text) = s {
4406 text.contains("Retrying") && text.contains("seconds")
4407 } else {
4408 false
4409 }
4410 })
4411 })
4412 .count()
4413 });
4414 assert_eq!(
4415 retry_count, MAX_RETRY_ATTEMPTS as usize,
4416 "Should have scheduled third retry"
4417 );
4418
4419 // Check retry state updated
4420 thread.read_with(cx, |thread, _| {
4421 assert!(thread.retry_state.is_some(), "Should have retry state");
4422 let retry_state = thread.retry_state.as_ref().unwrap();
4423 assert_eq!(
4424 retry_state.attempt, MAX_RETRY_ATTEMPTS,
4425 "Should be at max retry attempt"
4426 );
4427 assert_eq!(
4428 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4429 "Should have correct max attempts"
4430 );
4431 });
4432
4433 // Advance clock for third retry (exponential backoff)
4434 cx.executor()
4435 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
4436 cx.run_until_parked();
4437
4438 // No more retries should be scheduled after clock was advanced.
4439 let retry_count = thread.update(cx, |thread, _| {
4440 thread
4441 .messages
4442 .iter()
4443 .filter(|m| {
4444 m.ui_only
4445 && m.segments.iter().any(|s| {
4446 if let MessageSegment::Text(text) = s {
4447 text.contains("Retrying") && text.contains("seconds")
4448 } else {
4449 false
4450 }
4451 })
4452 })
4453 .count()
4454 });
4455 assert_eq!(
4456 retry_count, MAX_RETRY_ATTEMPTS as usize,
4457 "Should not exceed max retries"
4458 );
4459
4460 // Final completion count should be initial + max retries
4461 assert_eq!(
4462 *completion_count.lock(),
4463 (MAX_RETRY_ATTEMPTS + 1) as usize,
4464 "Should have made initial + max retry attempts"
4465 );
4466 }
4467
4468 #[gpui::test]
4469 async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4470 init_test_settings(cx);
4471
4472 let project = create_test_project(cx, json!({})).await;
4473 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4474
4475 // Create model that returns overloaded error
4476 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4477
4478 // Insert a user message
4479 thread.update(cx, |thread, cx| {
4480 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4481 });
4482
4483 // Track events
4484 let retries_failed = Arc::new(Mutex::new(false));
4485 let retries_failed_clone = retries_failed.clone();
4486
4487 let _subscription = thread.update(cx, |_, cx| {
4488 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4489 if let ThreadEvent::RetriesFailed { .. } = event {
4490 *retries_failed_clone.lock() = true;
4491 }
4492 })
4493 });
4494
4495 // Start initial completion
4496 thread.update(cx, |thread, cx| {
4497 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4498 });
4499 cx.run_until_parked();
4500
4501 // Advance through all retries
4502 for i in 0..MAX_RETRY_ATTEMPTS {
4503 let delay = if i == 0 {
4504 BASE_RETRY_DELAY_SECS
4505 } else {
4506 BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
4507 };
4508 cx.executor().advance_clock(Duration::from_secs(delay));
4509 cx.run_until_parked();
4510 }
4511
4512 // After the 3rd retry is scheduled, we need to wait for it to execute and fail
4513 // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
4514 let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
4515 cx.executor()
4516 .advance_clock(Duration::from_secs(final_delay));
4517 cx.run_until_parked();
4518
4519 let retry_count = thread.update(cx, |thread, _| {
4520 thread
4521 .messages
4522 .iter()
4523 .filter(|m| {
4524 m.ui_only
4525 && m.segments.iter().any(|s| {
4526 if let MessageSegment::Text(text) = s {
4527 text.contains("Retrying") && text.contains("seconds")
4528 } else {
4529 false
4530 }
4531 })
4532 })
4533 .count()
4534 });
4535
4536 // After max retries, should emit RetriesFailed event
4537 assert_eq!(
4538 retry_count, MAX_RETRY_ATTEMPTS as usize,
4539 "Should have attempted max retries"
4540 );
4541 assert!(
4542 *retries_failed.lock(),
4543 "Should emit RetriesFailed event after max retries exceeded"
4544 );
4545
4546 // Retry state should be cleared
4547 thread.read_with(cx, |thread, _| {
4548 assert!(
4549 thread.retry_state.is_none(),
4550 "Retry state should be cleared after max retries"
4551 );
4552
4553 // Verify we have the expected number of retry messages
4554 let retry_messages = thread
4555 .messages
4556 .iter()
4557 .filter(|msg| msg.ui_only && msg.role == Role::System)
4558 .count();
4559 assert_eq!(
4560 retry_messages, MAX_RETRY_ATTEMPTS as usize,
4561 "Should have one retry message per attempt"
4562 );
4563 });
4564 }
4565
4566 #[gpui::test]
4567 async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
4568 init_test_settings(cx);
4569
4570 let project = create_test_project(cx, json!({})).await;
4571 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4572
4573 // We'll use a wrapper to switch behavior after first failure
4574 struct RetryTestModel {
4575 inner: Arc<FakeLanguageModel>,
4576 failed_once: Arc<Mutex<bool>>,
4577 }
4578
4579 impl LanguageModel for RetryTestModel {
4580 fn id(&self) -> LanguageModelId {
4581 self.inner.id()
4582 }
4583
4584 fn name(&self) -> LanguageModelName {
4585 self.inner.name()
4586 }
4587
4588 fn provider_id(&self) -> LanguageModelProviderId {
4589 self.inner.provider_id()
4590 }
4591
4592 fn provider_name(&self) -> LanguageModelProviderName {
4593 self.inner.provider_name()
4594 }
4595
4596 fn supports_tools(&self) -> bool {
4597 self.inner.supports_tools()
4598 }
4599
4600 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4601 self.inner.supports_tool_choice(choice)
4602 }
4603
4604 fn supports_images(&self) -> bool {
4605 self.inner.supports_images()
4606 }
4607
4608 fn telemetry_id(&self) -> String {
4609 self.inner.telemetry_id()
4610 }
4611
4612 fn max_token_count(&self) -> u64 {
4613 self.inner.max_token_count()
4614 }
4615
4616 fn count_tokens(
4617 &self,
4618 request: LanguageModelRequest,
4619 cx: &App,
4620 ) -> BoxFuture<'static, Result<u64>> {
4621 self.inner.count_tokens(request, cx)
4622 }
4623
4624 fn stream_completion(
4625 &self,
4626 request: LanguageModelRequest,
4627 cx: &AsyncApp,
4628 ) -> BoxFuture<
4629 'static,
4630 Result<
4631 BoxStream<
4632 'static,
4633 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4634 >,
4635 LanguageModelCompletionError,
4636 >,
4637 > {
4638 if !*self.failed_once.lock() {
4639 *self.failed_once.lock() = true;
4640 let provider = self.provider_name();
4641 // Return error on first attempt
4642 let stream = futures::stream::once(async move {
4643 Err(LanguageModelCompletionError::ServerOverloaded {
4644 provider,
4645 retry_after: None,
4646 })
4647 });
4648 async move { Ok(stream.boxed()) }.boxed()
4649 } else {
4650 // Succeed on retry
4651 self.inner.stream_completion(request, cx)
4652 }
4653 }
4654
4655 fn as_fake(&self) -> &FakeLanguageModel {
4656 &self.inner
4657 }
4658 }
4659
4660 let model = Arc::new(RetryTestModel {
4661 inner: Arc::new(FakeLanguageModel::default()),
4662 failed_once: Arc::new(Mutex::new(false)),
4663 });
4664
4665 // Insert a user message
4666 thread.update(cx, |thread, cx| {
4667 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4668 });
4669
4670 // Track message deletions
4671 // Track when retry completes successfully
4672 let retry_completed = Arc::new(Mutex::new(false));
4673 let retry_completed_clone = retry_completed.clone();
4674
4675 let _subscription = thread.update(cx, |_, cx| {
4676 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4677 if let ThreadEvent::StreamedCompletion = event {
4678 *retry_completed_clone.lock() = true;
4679 }
4680 })
4681 });
4682
4683 // Start completion
4684 thread.update(cx, |thread, cx| {
4685 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4686 });
4687 cx.run_until_parked();
4688
4689 // Get the retry message ID
4690 let retry_message_id = thread.read_with(cx, |thread, _| {
4691 thread
4692 .messages()
4693 .find(|msg| msg.role == Role::System && msg.ui_only)
4694 .map(|msg| msg.id)
4695 .expect("Should have a retry message")
4696 });
4697
4698 // Wait for retry
4699 cx.executor()
4700 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4701 cx.run_until_parked();
4702
4703 // Stream some successful content
4704 let fake_model = model.as_fake();
4705 // After the retry, there should be a new pending completion
4706 let pending = fake_model.pending_completions();
4707 assert!(
4708 !pending.is_empty(),
4709 "Should have a pending completion after retry"
4710 );
4711 fake_model.stream_completion_response(&pending[0], "Success!");
4712 fake_model.end_completion_stream(&pending[0]);
4713 cx.run_until_parked();
4714
4715 // Check that the retry completed successfully
4716 assert!(
4717 *retry_completed.lock(),
4718 "Retry should have completed successfully"
4719 );
4720
4721 // Retry message should still exist but be marked as ui_only
4722 thread.read_with(cx, |thread, _| {
4723 let retry_msg = thread
4724 .message(retry_message_id)
4725 .expect("Retry message should still exist");
4726 assert!(retry_msg.ui_only, "Retry message should be ui_only");
4727 assert_eq!(
4728 retry_msg.role,
4729 Role::System,
4730 "Retry message should have System role"
4731 );
4732 });
4733 }
4734
4735 #[gpui::test]
4736 async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4737 init_test_settings(cx);
4738
4739 let project = create_test_project(cx, json!({})).await;
4740 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4741
4742 // Create a model that fails once then succeeds
4743 struct FailOnceModel {
4744 inner: Arc<FakeLanguageModel>,
4745 failed_once: Arc<Mutex<bool>>,
4746 }
4747
4748 impl LanguageModel for FailOnceModel {
4749 fn id(&self) -> LanguageModelId {
4750 self.inner.id()
4751 }
4752
4753 fn name(&self) -> LanguageModelName {
4754 self.inner.name()
4755 }
4756
4757 fn provider_id(&self) -> LanguageModelProviderId {
4758 self.inner.provider_id()
4759 }
4760
4761 fn provider_name(&self) -> LanguageModelProviderName {
4762 self.inner.provider_name()
4763 }
4764
4765 fn supports_tools(&self) -> bool {
4766 self.inner.supports_tools()
4767 }
4768
4769 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4770 self.inner.supports_tool_choice(choice)
4771 }
4772
4773 fn supports_images(&self) -> bool {
4774 self.inner.supports_images()
4775 }
4776
4777 fn telemetry_id(&self) -> String {
4778 self.inner.telemetry_id()
4779 }
4780
4781 fn max_token_count(&self) -> u64 {
4782 self.inner.max_token_count()
4783 }
4784
4785 fn count_tokens(
4786 &self,
4787 request: LanguageModelRequest,
4788 cx: &App,
4789 ) -> BoxFuture<'static, Result<u64>> {
4790 self.inner.count_tokens(request, cx)
4791 }
4792
4793 fn stream_completion(
4794 &self,
4795 request: LanguageModelRequest,
4796 cx: &AsyncApp,
4797 ) -> BoxFuture<
4798 'static,
4799 Result<
4800 BoxStream<
4801 'static,
4802 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4803 >,
4804 LanguageModelCompletionError,
4805 >,
4806 > {
4807 if !*self.failed_once.lock() {
4808 *self.failed_once.lock() = true;
4809 let provider = self.provider_name();
4810 // Return error on first attempt
4811 let stream = futures::stream::once(async move {
4812 Err(LanguageModelCompletionError::ServerOverloaded {
4813 provider,
4814 retry_after: None,
4815 })
4816 });
4817 async move { Ok(stream.boxed()) }.boxed()
4818 } else {
4819 // Succeed on retry
4820 self.inner.stream_completion(request, cx)
4821 }
4822 }
4823 }
4824
4825 let fail_once_model = Arc::new(FailOnceModel {
4826 inner: Arc::new(FakeLanguageModel::default()),
4827 failed_once: Arc::new(Mutex::new(false)),
4828 });
4829
4830 // Insert a user message
4831 thread.update(cx, |thread, cx| {
4832 thread.insert_user_message(
4833 "Test message",
4834 ContextLoadResult::default(),
4835 None,
4836 vec![],
4837 cx,
4838 );
4839 });
4840
4841 // Start completion with fail-once model
4842 thread.update(cx, |thread, cx| {
4843 thread.send_to_model(
4844 fail_once_model.clone(),
4845 CompletionIntent::UserPrompt,
4846 None,
4847 cx,
4848 );
4849 });
4850
4851 cx.run_until_parked();
4852
4853 // Verify retry state exists after first failure
4854 thread.read_with(cx, |thread, _| {
4855 assert!(
4856 thread.retry_state.is_some(),
4857 "Should have retry state after failure"
4858 );
4859 });
4860
4861 // Wait for retry delay
4862 cx.executor()
4863 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4864 cx.run_until_parked();
4865
4866 // The retry should now use our FailOnceModel which should succeed
4867 // We need to help the FakeLanguageModel complete the stream
4868 let inner_fake = fail_once_model.inner.clone();
4869
4870 // Wait a bit for the retry to start
4871 cx.run_until_parked();
4872
4873 // Check for pending completions and complete them
4874 if let Some(pending) = inner_fake.pending_completions().first() {
4875 inner_fake.stream_completion_response(pending, "Success!");
4876 inner_fake.end_completion_stream(pending);
4877 }
4878 cx.run_until_parked();
4879
4880 thread.read_with(cx, |thread, _| {
4881 assert!(
4882 thread.retry_state.is_none(),
4883 "Retry state should be cleared after successful completion"
4884 );
4885
4886 let has_assistant_message = thread
4887 .messages
4888 .iter()
4889 .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4890 assert!(
4891 has_assistant_message,
4892 "Should have an assistant message after successful retry"
4893 );
4894 });
4895 }
4896
4897 #[gpui::test]
4898 async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4899 init_test_settings(cx);
4900
4901 let project = create_test_project(cx, json!({})).await;
4902 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4903
4904 // Create a model that returns rate limit error with retry_after
4905 struct RateLimitModel {
4906 inner: Arc<FakeLanguageModel>,
4907 }
4908
4909 impl LanguageModel for RateLimitModel {
4910 fn id(&self) -> LanguageModelId {
4911 self.inner.id()
4912 }
4913
4914 fn name(&self) -> LanguageModelName {
4915 self.inner.name()
4916 }
4917
4918 fn provider_id(&self) -> LanguageModelProviderId {
4919 self.inner.provider_id()
4920 }
4921
4922 fn provider_name(&self) -> LanguageModelProviderName {
4923 self.inner.provider_name()
4924 }
4925
4926 fn supports_tools(&self) -> bool {
4927 self.inner.supports_tools()
4928 }
4929
4930 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4931 self.inner.supports_tool_choice(choice)
4932 }
4933
4934 fn supports_images(&self) -> bool {
4935 self.inner.supports_images()
4936 }
4937
4938 fn telemetry_id(&self) -> String {
4939 self.inner.telemetry_id()
4940 }
4941
4942 fn max_token_count(&self) -> u64 {
4943 self.inner.max_token_count()
4944 }
4945
4946 fn count_tokens(
4947 &self,
4948 request: LanguageModelRequest,
4949 cx: &App,
4950 ) -> BoxFuture<'static, Result<u64>> {
4951 self.inner.count_tokens(request, cx)
4952 }
4953
4954 fn stream_completion(
4955 &self,
4956 _request: LanguageModelRequest,
4957 _cx: &AsyncApp,
4958 ) -> BoxFuture<
4959 'static,
4960 Result<
4961 BoxStream<
4962 'static,
4963 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4964 >,
4965 LanguageModelCompletionError,
4966 >,
4967 > {
4968 let provider = self.provider_name();
4969 async move {
4970 let stream = futures::stream::once(async move {
4971 Err(LanguageModelCompletionError::RateLimitExceeded {
4972 provider,
4973 retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4974 })
4975 });
4976 Ok(stream.boxed())
4977 }
4978 .boxed()
4979 }
4980
4981 fn as_fake(&self) -> &FakeLanguageModel {
4982 &self.inner
4983 }
4984 }
4985
4986 let model = Arc::new(RateLimitModel {
4987 inner: Arc::new(FakeLanguageModel::default()),
4988 });
4989
4990 // Insert a user message
4991 thread.update(cx, |thread, cx| {
4992 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4993 });
4994
4995 // Start completion
4996 thread.update(cx, |thread, cx| {
4997 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4998 });
4999
5000 cx.run_until_parked();
5001
5002 let retry_count = thread.update(cx, |thread, _| {
5003 thread
5004 .messages
5005 .iter()
5006 .filter(|m| {
5007 m.ui_only
5008 && m.segments.iter().any(|s| {
5009 if let MessageSegment::Text(text) = s {
5010 text.contains("rate limit exceeded")
5011 } else {
5012 false
5013 }
5014 })
5015 })
5016 .count()
5017 });
5018 assert_eq!(retry_count, 1, "Should have scheduled one retry");
5019
5020 thread.read_with(cx, |thread, _| {
5021 assert!(
5022 thread.retry_state.is_none(),
5023 "Rate limit errors should not set retry_state"
5024 );
5025 });
5026
5027 // Verify we have one retry message
5028 thread.read_with(cx, |thread, _| {
5029 let retry_messages = thread
5030 .messages
5031 .iter()
5032 .filter(|msg| {
5033 msg.ui_only
5034 && msg.segments.iter().any(|seg| {
5035 if let MessageSegment::Text(text) = seg {
5036 text.contains("rate limit exceeded")
5037 } else {
5038 false
5039 }
5040 })
5041 })
5042 .count();
5043 assert_eq!(
5044 retry_messages, 1,
5045 "Should have one rate limit retry message"
5046 );
5047 });
5048
5049 // Check that retry message doesn't include attempt count
5050 thread.read_with(cx, |thread, _| {
5051 let retry_message = thread
5052 .messages
5053 .iter()
5054 .find(|msg| msg.role == Role::System && msg.ui_only)
5055 .expect("Should have a retry message");
5056
5057 // Check that the message doesn't contain attempt count
5058 if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5059 assert!(
5060 !text.contains("attempt"),
5061 "Rate limit retry message should not contain attempt count"
5062 );
5063 assert!(
5064 text.contains(&format!(
5065 "Retrying in {} seconds",
5066 TEST_RATE_LIMIT_RETRY_SECS
5067 )),
5068 "Rate limit retry message should contain retry delay"
5069 );
5070 }
5071 });
5072 }
5073
5074 #[gpui::test]
5075 async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5076 init_test_settings(cx);
5077
5078 let project = create_test_project(cx, json!({})).await;
5079 let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5080
5081 // Insert a regular user message
5082 thread.update(cx, |thread, cx| {
5083 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5084 });
5085
5086 // Insert a UI-only message (like our retry notifications)
5087 thread.update(cx, |thread, cx| {
5088 let id = thread.next_message_id.post_inc();
5089 thread.messages.push(Message {
5090 id,
5091 role: Role::System,
5092 segments: vec![MessageSegment::Text(
5093 "This is a UI-only message that should not be sent to the model".to_string(),
5094 )],
5095 loaded_context: LoadedContext::default(),
5096 creases: Vec::new(),
5097 is_hidden: true,
5098 ui_only: true,
5099 });
5100 cx.emit(ThreadEvent::MessageAdded(id));
5101 });
5102
5103 // Insert another regular message
5104 thread.update(cx, |thread, cx| {
5105 thread.insert_user_message(
5106 "How are you?",
5107 ContextLoadResult::default(),
5108 None,
5109 vec![],
5110 cx,
5111 );
5112 });
5113
5114 // Generate the completion request
5115 let request = thread.update(cx, |thread, cx| {
5116 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5117 });
5118
5119 // Verify that the request only contains non-UI-only messages
5120 // Should have system prompt + 2 user messages, but not the UI-only message
5121 let user_messages: Vec<_> = request
5122 .messages
5123 .iter()
5124 .filter(|msg| msg.role == Role::User)
5125 .collect();
5126 assert_eq!(
5127 user_messages.len(),
5128 2,
5129 "Should have exactly 2 user messages"
5130 );
5131
5132 // Verify the UI-only content is not present anywhere in the request
5133 let request_text = request
5134 .messages
5135 .iter()
5136 .flat_map(|msg| &msg.content)
5137 .filter_map(|content| match content {
5138 MessageContent::Text(text) => Some(text.as_str()),
5139 _ => None,
5140 })
5141 .collect::<String>();
5142
5143 assert!(
5144 !request_text.contains("UI-only message"),
5145 "UI-only message content should not be in the request"
5146 );
5147
5148 // Verify the thread still has all 3 messages (including UI-only)
5149 thread.read_with(cx, |thread, _| {
5150 assert_eq!(
5151 thread.messages().count(),
5152 3,
5153 "Thread should have 3 messages"
5154 );
5155 assert_eq!(
5156 thread.messages().filter(|m| m.ui_only).count(),
5157 1,
5158 "Thread should have 1 UI-only message"
5159 );
5160 });
5161
5162 // Verify that UI-only messages are not serialized
5163 let serialized = thread
5164 .update(cx, |thread, cx| thread.serialize(cx))
5165 .await
5166 .unwrap();
5167 assert_eq!(
5168 serialized.messages.len(),
5169 2,
5170 "Serialized thread should only have 2 messages (no UI-only)"
5171 );
5172 }
5173
5174 #[gpui::test]
5175 async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5176 init_test_settings(cx);
5177
5178 let project = create_test_project(cx, json!({})).await;
5179 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5180
5181 // Create model that returns overloaded error
5182 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5183
5184 // Insert a user message
5185 thread.update(cx, |thread, cx| {
5186 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5187 });
5188
5189 // Start completion
5190 thread.update(cx, |thread, cx| {
5191 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5192 });
5193
5194 cx.run_until_parked();
5195
5196 // Verify retry was scheduled by checking for retry message
5197 let has_retry_message = thread.read_with(cx, |thread, _| {
5198 thread.messages.iter().any(|m| {
5199 m.ui_only
5200 && m.segments.iter().any(|s| {
5201 if let MessageSegment::Text(text) = s {
5202 text.contains("Retrying") && text.contains("seconds")
5203 } else {
5204 false
5205 }
5206 })
5207 })
5208 });
5209 assert!(has_retry_message, "Should have scheduled a retry");
5210
5211 // Cancel the completion before the retry happens
5212 thread.update(cx, |thread, cx| {
5213 thread.cancel_last_completion(None, cx);
5214 });
5215
5216 cx.run_until_parked();
5217
5218 // The retry should not have happened - no pending completions
5219 let fake_model = model.as_fake();
5220 assert_eq!(
5221 fake_model.pending_completions().len(),
5222 0,
5223 "Should have no pending completions after cancellation"
5224 );
5225
5226 // Verify the retry was cancelled by checking retry state
5227 thread.read_with(cx, |thread, _| {
5228 if let Some(retry_state) = &thread.retry_state {
5229 panic!(
5230 "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5231 retry_state.attempt, retry_state.max_attempts, retry_state.intent
5232 );
5233 }
5234 });
5235 }
5236
5237 fn test_summarize_error(
5238 model: &Arc<dyn LanguageModel>,
5239 thread: &Entity<Thread>,
5240 cx: &mut TestAppContext,
5241 ) {
5242 thread.update(cx, |thread, cx| {
5243 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5244 thread.send_to_model(
5245 model.clone(),
5246 CompletionIntent::ThreadSummarization,
5247 None,
5248 cx,
5249 );
5250 });
5251
5252 let fake_model = model.as_fake();
5253 simulate_successful_response(&fake_model, cx);
5254
5255 thread.read_with(cx, |thread, _| {
5256 assert!(matches!(thread.summary(), ThreadSummary::Generating));
5257 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5258 });
5259
5260 // Simulate summary request ending
5261 cx.run_until_parked();
5262 fake_model.end_last_completion_stream();
5263 cx.run_until_parked();
5264
5265 // State is set to Error and default message
5266 thread.read_with(cx, |thread, _| {
5267 assert!(matches!(thread.summary(), ThreadSummary::Error));
5268 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5269 });
5270 }
5271
5272 fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
5273 cx.run_until_parked();
5274 fake_model.stream_last_completion_response("Assistant response");
5275 fake_model.end_last_completion_stream();
5276 cx.run_until_parked();
5277 }
5278
5279 fn init_test_settings(cx: &mut TestAppContext) {
5280 cx.update(|cx| {
5281 let settings_store = SettingsStore::test(cx);
5282 cx.set_global(settings_store);
5283 language::init(cx);
5284 Project::init_settings(cx);
5285 AgentSettings::register(cx);
5286 prompt_store::init(cx);
5287 thread_store::init(cx);
5288 workspace::init_settings(cx);
5289 language_model::init_settings(cx);
5290 ThemeSettings::register(cx);
5291 ToolRegistry::default_global(cx);
5292 assistant_tool::init(cx);
5293
5294 let http_client = Arc::new(http_client::HttpClientWithUrl::new(
5295 http_client::FakeHttpClient::with_200_response(),
5296 "http://localhost".to_string(),
5297 None,
5298 ));
5299 assistant_tools::init(http_client, cx);
5300 });
5301 }
5302
5303 // Helper to create a test project with test files
5304 async fn create_test_project(
5305 cx: &mut TestAppContext,
5306 files: serde_json::Value,
5307 ) -> Entity<Project> {
5308 let fs = FakeFs::new(cx.executor());
5309 fs.insert_tree(path!("/test"), files).await;
5310 Project::test(fs, [path!("/test").as_ref()], cx).await
5311 }
5312
5313 async fn setup_test_environment(
5314 cx: &mut TestAppContext,
5315 project: Entity<Project>,
5316 ) -> (
5317 Entity<Workspace>,
5318 Entity<ThreadStore>,
5319 Entity<Thread>,
5320 Entity<ContextStore>,
5321 Arc<dyn LanguageModel>,
5322 ) {
5323 let (workspace, cx) =
5324 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5325
5326 let thread_store = cx
5327 .update(|_, cx| {
5328 ThreadStore::load(
5329 project.clone(),
5330 cx.new(|_| ToolWorkingSet::default()),
5331 None,
5332 Arc::new(PromptBuilder::new(None).unwrap()),
5333 cx,
5334 )
5335 })
5336 .await
5337 .unwrap();
5338
5339 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5340 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5341
5342 let provider = Arc::new(FakeLanguageModelProvider);
5343 let model = provider.test_model();
5344 let model: Arc<dyn LanguageModel> = Arc::new(model);
5345
5346 cx.update(|_, cx| {
5347 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5348 registry.set_default_model(
5349 Some(ConfiguredModel {
5350 provider: provider.clone(),
5351 model: model.clone(),
5352 }),
5353 cx,
5354 );
5355 registry.set_thread_summary_model(
5356 Some(ConfiguredModel {
5357 provider,
5358 model: model.clone(),
5359 }),
5360 cx,
5361 );
5362 })
5363 });
5364
5365 (workspace, thread_store, thread, context_store, model)
5366 }
5367
5368 async fn add_file_to_context(
5369 project: &Entity<Project>,
5370 context_store: &Entity<ContextStore>,
5371 path: &str,
5372 cx: &mut TestAppContext,
5373 ) -> Result<Entity<language::Buffer>> {
5374 let buffer_path = project
5375 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5376 .unwrap();
5377
5378 let buffer = project
5379 .update(cx, |project, cx| {
5380 project.open_buffer(buffer_path.clone(), cx)
5381 })
5382 .await
5383 .unwrap();
5384
5385 context_store.update(cx, |context_store, cx| {
5386 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5387 });
5388
5389 Ok(buffer)
5390 }
5391}