1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9 tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
10};
11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
12use anyhow::{Result, anyhow};
13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
14use chrono::{DateTime, Utc};
15use client::{ModelRequestUsage, RequestUsage};
16use collections::HashMap;
17use feature_flags::{self, FeatureFlagAppExt};
18use futures::{FutureExt, StreamExt as _, future::Shared};
19use git::repository::DiffType;
20use gpui::{
21 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
22 WeakEntity, Window,
23};
24use language_model::{
25 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
27 LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
28 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
29 PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
30};
31use postage::stream::Stream as _;
32use project::{
33 Project,
34 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
35};
36use prompt_store::{ModelContext, PromptBuilder};
37use proto::Plan;
38use schemars::JsonSchema;
39use serde::{Deserialize, Serialize};
40use settings::Settings;
41use std::{
42 io::Write,
43 ops::Range,
44 sync::Arc,
45 time::{Duration, Instant},
46};
47use thiserror::Error;
48use util::{ResultExt as _, debug_panic, post_inc};
49use uuid::Uuid;
50use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
51
/// Maximum number of automatic retry attempts after a failed completion request.
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay between retry attempts, in seconds.
const BASE_RETRY_DELAY_SECS: u64 = 5;
54
/// Globally unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Creates a fresh random (UUID v4) thread ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    /// Wraps an existing ID string (e.g. read back from storage); no validation is performed.
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
77
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a fresh random (UUID v4) prompt ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
95
/// Monotonically increasing, thread-local identifier for a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one
    /// (used by the thread to mint IDs for newly inserted messages).
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    /// The raw numeric value of this ID.
    pub fn as_usize(&self) -> usize {
        self.0
    }
}
108
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
118
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces (text / thinking / redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context (files, symbols, …) attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
    /// Hidden messages are kept in the transcript but not rendered (e.g. synthetic "continue" prompts).
    pub is_hidden: bool,
    /// UI-only messages are rendered but never sent to the model nor persisted.
    pub ui_only: bool,
}
130
impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed
    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
    ///
    /// NOTE(review): this requires *all* segments to pass `should_display()`, while the doc
    /// above says "any" — confirm which is intended.
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }

    /// Appends thinking text, coalescing into the trailing `Thinking` segment when present.
    ///
    /// A `Some` signature replaces the trailing segment's current signature; `None` leaves it unchanged.
    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        if let Some(MessageSegment::Thinking {
            text: segment,
            signature: current_signature,
        }) = self.segments.last_mut()
        {
            if let Some(signature) = signature {
                *current_signature = Some(signature);
            }
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Thinking {
                text: text.to_string(),
                signature,
            });
        }
    }

    /// Appends a redacted-thinking segment; never coalesced with neighbors.
    pub fn push_redacted_thinking(&mut self, data: String) {
        self.segments.push(MessageSegment::RedactedThinking(data));
    }

    /// Appends text, coalescing into the trailing `Text` segment when present.
    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

    /// Renders the message as plain text: loaded context first, then each segment in order.
    /// Thinking segments are wrapped in `<think>` tags; redacted thinking is omitted.
    pub fn to_string(&self) -> String {
        let mut result = String::new();

        if !self.loaded_context.text.is_empty() {
            result.push_str(&self.loaded_context.text);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    result.push_str("<think>\n");
                    result.push_str(text);
                    result.push_str("\n</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        result
    }
}
190
/// One contiguous piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(String),
}

impl MessageSegment {
    /// Returns whether this segment has user-visible content worth displaying.
    ///
    /// Text and thinking segments are displayable only when non-empty; redacted
    /// thinking is never shown.
    pub fn should_display(&self) -> bool {
        match self {
            // Fixed: these arms previously returned `text.is_empty()`, inverting
            // the meaning of "should display" (empty segments were reported as
            // displayable while non-empty ones were hidden).
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }

    /// The text of a `Text` segment, or `None` for any other variant.
    pub fn text(&self) -> Option<&str> {
        match self {
            MessageSegment::Text(text) => Some(text),
            _ => None,
        }
    }
}
217
/// Point-in-time capture of the project's worktrees, taken when a thread starts.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved edits at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Captured git metadata for a worktree; each field is `None` when unavailable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
238
/// Associates a git checkpoint with the message whose edits it precedes,
/// so the project can be restored to its state before that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up / thumbs-down feedback a user can leave on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
250
/// State of the most recent checkpoint-restore attempt: either still in
/// flight or failed with an error. Cleared on success.
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
260
261impl LastRestoreCheckpoint {
262 pub fn message_id(&self) -> MessageId {
263 match self {
264 LastRestoreCheckpoint::Pending { message_id } => *message_id,
265 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
266 }
267 }
268}
269
/// Lifecycle of the long-form thread summary: not started, generating for a
/// particular message, or finished with text.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
282
283impl DetailedSummaryState {
284 fn text(&self) -> Option<SharedString> {
285 if let Self::Generated { text, .. } = self {
286 Some(text.clone())
287 } else {
288 None
289 }
290 }
291}
292
/// Aggregated token counts for a thread, relative to the active model's
/// context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Context-window size of the active model (0 when no model is selected).
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies current usage against the context window.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via the
        // ZED_THREAD_WARNING_THRESHOLD env var. A missing or malformed value
        // falls back to the default instead of panicking (previously this
        // called `.parse().unwrap()`, crashing on any non-float value).
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy with `tokens` added to the running total.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close the thread is to exhausting the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
337
/// Where a pending completion request sits in the upstream queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request is being sent; no queue information yet.
    Sending,
    /// Request is waiting upstream at the given queue position.
    Queued { position: usize },
    /// Request has started streaming.
    Started,
}
344
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    // In-flight task generating the short summary (thread title).
    pending_summary: Task<Option<()>>,
    // In-flight task generating the detailed summary; result flows through the watch channel below.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    messages: Vec<Message>,
    // Next ID to mint; message IDs are monotonically increasing.
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per user message, for restore.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured when the current turn started; finalized when generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history, parallel to completion requests.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the last streamed chunk; used for staleness detection.
    last_received_chunk_at: Option<Instant>,
    // Test/debug hook invoked with each request and its received events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
386
/// Tracks progress of automatic retries after failed completion requests.
#[derive(Clone, Debug)]
struct RetryState {
    attempt: u8,
    max_attempts: u8,
    // Intent of the original request, so the retry can resend it faithfully.
    intent: CompletionIntent,
}
393
/// Lifecycle of the short thread summary used as the thread's title.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    Pending,
    Generating,
    Ready(SharedString),
    Error,
}
401
402impl ThreadSummary {
403 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
404
405 pub fn or_default(&self) -> SharedString {
406 self.unwrap_or(Self::DEFAULT)
407 }
408
409 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
410 self.ready().unwrap_or_else(|| message.into())
411 }
412
413 pub fn ready(&self) -> Option<SharedString> {
414 match self {
415 ThreadSummary::Ready(summary) => Some(summary.clone()),
416 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
417 }
418 }
419}
420
/// Recorded when a request overflows the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
428
429impl Thread {
    /// Creates an empty thread using the globally configured default model and
    /// the default agent profile from settings.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project snapshot eagerly in the background so it's
            // ready when the thread is first serialized.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
485
    /// Reconstructs a thread from its serialized form.
    ///
    /// The serialized model is re-resolved against the registry (falling back to
    /// the default model), and missing completion mode / profile fall back to
    /// settings. Attached context and UI-only messages are not persisted, so
    /// they come back empty.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the persisted model if it's still available; otherwise the default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the rendered context text survives serialization;
                    // the context handles themselves are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
609
    /// Installs a debug/test hook invoked with each request and the events received for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches the agent profile, emitting `ProfileChanged` only on an actual change.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the last-updated timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt: subsequent requests are attributed to a fresh prompt ID.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to the global default
    /// if none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
668
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Sets the thread summary (title). No-op while a summary is pending or
    /// generating; an empty string is replaced with the default title, and the
    /// `SummaryChanged` event fires only on an actual change.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
699
    /// Looks up a message by ID via binary search — relies on `messages` being
    /// sorted by monotonically increasing IDs.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// A thread is generating while completions are in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    ///
    /// NOTE(review): the `None` guarantee above relies on `last_received_chunk_at`
    /// being cleared when generation ends — confirm that invariant holds elsewhere.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before the stream is considered stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }
725
    /// Records that a streamed chunk just arrived (feeds staleness detection).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue position of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given ID, if still pending.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses waiting on user confirmation before they may run.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
761
    /// Restores the project to the given git checkpoint and, on success,
    /// truncates the thread back to the checkpoint's message. Progress/failure
    /// is surfaced through `last_restore_checkpoint` and `CheckpointChanged`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI can show progress.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so it can be shown for this message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way the pending checkpoint for the current turn is dropped.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
797
798 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
799 let pending_checkpoint = if self.is_generating() {
800 return;
801 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
802 checkpoint
803 } else {
804 return;
805 };
806
807 self.finalize_checkpoint(pending_checkpoint, cx);
808 }
809
    /// Compares the turn-start checkpoint against the project's current state
    /// and records it only if the turn actually changed something. If the
    /// comparison itself fails, the checkpoint is kept conservatively.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // Comparison failure counts as "changed" so the checkpoint is kept.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint: keep the pending one rather than lose it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
844
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
855
    /// Removes the message with `message_id` and everything after it, dropping
    /// the removed messages' checkpoints. No-op if the ID isn't found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        // rposition: search from the end, where the target usually is.
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Iterates the context items attached to the given message (empty if the
    /// message doesn't exist or has no context).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
877
878 pub fn is_turn_end(&self, ix: usize) -> bool {
879 if self.messages.is_empty() {
880 return false;
881 }
882
883 if !self.is_generating() && ix == self.messages.len() - 1 {
884 return true;
885 }
886
887 let Some(message) = self.messages.get(ix) else {
888 return false;
889 };
890
891 if message.role != Role::Assistant {
892 return false;
893 }
894
895 self.messages
896 .get(ix + 1)
897 .and_then(|message| {
898 self.message(message.id)
899 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
900 })
901 .unwrap_or(false)
902 }
903
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            // Errored tool uses will never run, so they can't edit anything.
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }
926
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Results of tool calls made by the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Textual output of a finished tool use; image results currently yield `None`.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
955
    /// Return tools that are both enabled and supported by the model
    ///
    /// Returns an empty list when the model doesn't support tool calling at all.
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            self.profile
                .enabled_tools(cx)
                .into_iter()
                .filter_map(|(name, tool)| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name: name.into(),
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            Vec::default()
        }
    }
980
    /// Appends a user message with its loaded context, marking any referenced
    /// buffers as read in the action log and (optionally) recording a git
    /// checkpoint that will be finalized when this turn completes.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        // Remember the project state as of this message so the turn's edits
        // can be rolled back.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
1017
    /// Appends a hidden "Continue where you left off" user message, used to
    /// resume generation without the user typing anything. Clears any pending
    /// checkpoint since this isn't a real user turn.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    /// Appends a visible assistant message with the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1046
    /// Appends a message with a freshly minted ID, bumps the updated-at
    /// timestamp, and emits `MessageAdded`.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
            ui_only: false,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1070
    /// Replaces the role, segments, creases and (optionally) context of an
    /// existing message, optionally recording a new git checkpoint for it.
    /// Returns `false` if no message with the given ID exists.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        // `None` keeps the message's existing context untouched.
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1103
1104 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1105 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1106 return false;
1107 };
1108 self.messages.remove(index);
1109 self.touch_updated_at();
1110 cx.emit(ThreadEvent::MessageDeleted(id));
1111 true
1112 }
1113
    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
    pub fn text(&self) -> String {
        let mut text = String::new();

        for message in &self.messages {
            // Role prefix, one message per paragraph.
            text.push_str(match message.role {
                language_model::Role::User => "User:",
                language_model::Role::Assistant => "Agent:",
                language_model::Role::System => "System:",
            });
            text.push('\n');

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(content) => text.push_str(content),
                    MessageSegment::Thinking { text: content, .. } => {
                        text.push_str(&format!("<think>{}</think>", content))
                    }
                    // Redacted thinking carries no human-readable content.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }
            text.push('\n');
        }

        text
    }
1142
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot before reading the thread state.
    /// UI-only messages are excluded; everything else (segments, tool
    /// uses/results, creases, token usage, model/profile configuration) is
    /// captured into a [`SerializedThread`].
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                // UI-only messages are presentation artifacts and are not persisted.
                messages: this
                    .messages()
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1229
    /// Returns how many agentic turns may still be sent to the model.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1233
    /// Sets how many agentic turns may still be sent to the model.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1237
    /// Sends the current thread state to `model` as a streaming completion
    /// request, consuming one remaining turn.
    ///
    /// Does nothing when no turns remain. Pending auto-generated
    /// notifications are flushed before the request is built so they are
    /// included in it.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        self.flush_notifications(model.clone(), intent, cx);

        let request = self.to_completion_request(model.clone(), intent, cx);

        self.stream_completion(request, model, intent, window, cx);
    }
1257
1258 pub fn used_tools_since_last_user_message(&self) -> bool {
1259 for message in self.messages.iter().rev() {
1260 if self.tool_use.message_has_tool_results(message.id) {
1261 return true;
1262 } else if message.role == Role::User {
1263 return false;
1264 }
1265 }
1266
1267 false
1268 }
1269
    /// Builds the [`LanguageModelRequest`] for this thread: system prompt,
    /// conversation messages, tool uses/results, available tools, and
    /// completion mode.
    ///
    /// Emits a [`ThreadEvent::ShowError`] (and omits the system prompt) if the
    /// prompt cannot be generated or the project context is not ready yet.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt requires the shared project context; if it isn't
        // loaded yet we proceed without a system message and surface an error.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message that is safe to mark for prompt
        // caching; only the final such message gets `cache = true` (below).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are paired with their results: the use goes on
            // this message, the result on a synthetic follow-up User message.
            // Pending (unfinished) tool uses are skipped and make the message
            // ineligible for caching.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn mode is only requested when the model supports it; otherwise
        // fall back to normal completion mode.
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1433
1434 fn to_summarize_request(
1435 &self,
1436 model: &Arc<dyn LanguageModel>,
1437 intent: CompletionIntent,
1438 added_user_message: String,
1439 cx: &App,
1440 ) -> LanguageModelRequest {
1441 let mut request = LanguageModelRequest {
1442 thread_id: None,
1443 prompt_id: None,
1444 intent: Some(intent),
1445 mode: None,
1446 messages: vec![],
1447 tools: Vec::new(),
1448 tool_choice: None,
1449 stop: Vec::new(),
1450 temperature: AgentSettings::temperature_for_model(model, cx),
1451 };
1452
1453 for message in &self.messages {
1454 let mut request_message = LanguageModelRequestMessage {
1455 role: message.role,
1456 content: Vec::new(),
1457 cache: false,
1458 };
1459
1460 for segment in &message.segments {
1461 match segment {
1462 MessageSegment::Text(text) => request_message
1463 .content
1464 .push(MessageContent::Text(text.clone())),
1465 MessageSegment::Thinking { .. } => {}
1466 MessageSegment::RedactedThinking(_) => {}
1467 }
1468 }
1469
1470 if request_message.content.is_empty() {
1471 continue;
1472 }
1473
1474 request.messages.push(request_message);
1475 }
1476
1477 request.messages.push(LanguageModelRequestMessage {
1478 role: Role::User,
1479 content: vec![MessageContent::Text(added_user_message)],
1480 cache: false,
1481 });
1482
1483 request
1484 }
1485
1486 /// Insert auto-generated notifications (if any) to the thread
1487 fn flush_notifications(
1488 &mut self,
1489 model: Arc<dyn LanguageModel>,
1490 intent: CompletionIntent,
1491 cx: &mut Context<Self>,
1492 ) {
1493 match intent {
1494 CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
1495 if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
1496 cx.emit(ThreadEvent::ToolFinished {
1497 tool_use_id: pending_tool_use.id.clone(),
1498 pending_tool_use: Some(pending_tool_use),
1499 });
1500 }
1501 }
1502 CompletionIntent::ThreadSummarization
1503 | CompletionIntent::ThreadContextSummarization
1504 | CompletionIntent::CreateFile
1505 | CompletionIntent::EditFile
1506 | CompletionIntent::InlineAssist
1507 | CompletionIntent::TerminalInlineAssist
1508 | CompletionIntent::GenerateGitCommitMessage => {}
1509 };
1510 }
1511
    /// If the action log reports any stale buffers, injects a simulated
    /// `project_notifications` tool call (use + result) into the most recent
    /// Assistant message so the model is informed of the changes.
    ///
    /// Returns the resulting pending tool use, or `None` when there is nothing
    /// to report, the tool is missing or disabled for the current profile, or
    /// no Assistant message exists yet.
    fn attach_tracked_files_state(
        &mut self,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Option<PendingToolUse> {
        let action_log = self.action_log.read(cx);

        // Early out when no buffers are stale.
        action_log.stale_buffers(cx).next()?;

        // Represent notification as a simulated `project_notifications` tool call
        let tool_name = Arc::from("project_notifications");
        let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
            debug_panic!("`project_notifications` tool not found");
            return None;
        };

        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return None;
        }

        let input = serde_json::json!({});
        let request = Arc::new(LanguageModelRequest::default()); // unused
        let window = None;
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model.clone(),
            window,
            cx,
        );

        // Synthesize a unique-enough id for the simulated call.
        let tool_use_id =
            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));

        let tool_use = LanguageModelToolUse {
            id: tool_use_id.clone(),
            name: tool_name.clone(),
            raw_input: "{}".to_string(),
            input: serde_json::json!({}),
            is_input_complete: true,
        };

        // Block until the tool finishes so the result can be attached synchronously.
        let tool_output = cx.background_executor().block(tool_result.output);

        // Attach a project_notification tool call to the latest existing
        // Assistant message. We cannot create a new Assistant message
        // because thinking models require a `thinking` block that we
        // cannot mock. We cannot send a notification as a normal
        // (non-tool-use) User message because this distracts Agent
        // too much.
        let tool_message_id = self
            .messages
            .iter()
            .enumerate()
            .rfind(|(_, message)| message.role == Role::Assistant)
            .map(|(_, message)| message.id)?;

        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: self.last_prompt_id.clone(),
        };

        self.tool_use
            .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            tool_output,
            self.configured_model.as_ref(),
        );

        pending_tool_use
    }
1589
    /// Streams a completion for `request` from `model`, applying events to the
    /// thread as they arrive and handling the final stop reason or error
    /// (including scheduling retries for retryable failures).
    ///
    /// The work runs in a spawned task tracked in `pending_completions` until
    /// it finishes or is canceled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record request/response pairs when a callback is registered.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            // Drives the event stream to completion; yields the final stop reason.
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        match event? {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking { data } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the current Assistant message,
                                // creating an empty one if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            } => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued { position } => {
                                            completion.queue_state =
                                                QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code,
                                            message,
                                            request_id: _,
                                            retry_after,
                                        } => {
                                            return Err(
                                                LanguageModelCompletionError::from_cloud_failure(
                                                    model.upstream_provider_name(),
                                                    code,
                                                    message,
                                                    retry_after.map(Duration::from_secs_f64),
                                                ),
                                            );
                                        }
                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
                                            thread.update_model_request_usage(
                                                amount as u32,
                                                limit,
                                                cx,
                                            );
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other tasks make progress between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;
            let mut retry_scheduled = false;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => {
                            match stop_reason {
                                StopReason::ToolUse => {
                                    let tool_uses =
                                        thread.use_pending_tools(window, model.clone(), cx);
                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                                }
                                StopReason::EndTurn | StopReason::MaxTokens => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });
                                }
                                StopReason::Refusal => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });

                                    // Remove the turn that was refused.
                                    //
                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                    {
                                        let mut messages_to_remove = Vec::new();

                                        for (ix, message) in
                                            thread.messages.iter().enumerate().rev()
                                        {
                                            messages_to_remove.push(message.id);

                                            if message.role == Role::User {
                                                if ix == 0 {
                                                    break;
                                                }

                                                if let Some(prev_message) =
                                                    thread.messages.get(ix - 1)
                                                {
                                                    if prev_message.role == Role::Assistant {
                                                        break;
                                                    }
                                                }
                                            }
                                        }

                                        for message_id in messages_to_remove {
                                            thread.delete_message(message_id, cx);
                                        }
                                    }

                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                        header: "Language model refusal".into(),
                                        message:
                                            "Model refused to generate content for safety reasons."
                                                .into(),
                                    }));
                                }
                            }

                            // We successfully completed, so cancel any remaining retries.
                            thread.retry_state = None;
                        }
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Fallback error surface for errors with no dedicated handling.
                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(completion_error) =
                                error.downcast_ref::<LanguageModelCompletionError>()
                            {
                                use LanguageModelCompletionError::*;
                                match &completion_error {
                                    PromptTooLarge { tokens, .. } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread
                                                .total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(model.max_token_count().saturating_add(1))
                                        });
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: tokens,
                                        });
                                        cx.notify();
                                    }
                                    RateLimitExceeded {
                                        retry_after: Some(retry_after),
                                        ..
                                    }
                                    | ServerOverloaded {
                                        retry_after: Some(retry_after),
                                        ..
                                    } => {
                                        thread.handle_rate_limit_error(
                                            &completion_error,
                                            *retry_after,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        retry_scheduled = true;
                                    }
                                    RateLimitExceeded { .. } | ServerOverloaded { .. } => {
                                        retry_scheduled = thread.handle_retryable_error(
                                            &completion_error,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        if !retry_scheduled {
                                            emit_generic_error(error, cx);
                                        }
                                    }
                                    ApiInternalServerError { .. }
                                    | ApiReadResponseError { .. }
                                    | HttpSend { .. } => {
                                        retry_scheduled = thread.handle_retryable_error(
                                            &completion_error,
                                            model.clone(),
                                            intent,
                                            window,
                                            cx,
                                        );
                                        if !retry_scheduled {
                                            emit_generic_error(error, cx);
                                        }
                                    }
                                    NoApiKey { .. }
                                    | HttpResponseError { .. }
                                    | BadRequestFormat { .. }
                                    | AuthenticationError { .. }
                                    | PermissionError { .. }
                                    | ApiEndpointNotFound { .. }
                                    | SerializeRequest { .. }
                                    | BuildRequestBody { .. }
                                    | DeserializeResponse { .. }
                                    | Other { .. } => emit_generic_error(error, cx),
                                }
                            } else {
                                emit_generic_error(error, cx);
                            }

                            if !retry_scheduled {
                                thread.cancel_last_completion(window, cx);
                            }
                        }
                    }

                    if !retry_scheduled {
                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
                    }

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
2074
2075 pub fn summarize(&mut self, cx: &mut Context<Self>) {
2076 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2077 println!("No thread summary model");
2078 return;
2079 };
2080
2081 if !model.provider.is_authenticated(cx) {
2082 return;
2083 }
2084
2085 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2086
2087 let request = self.to_summarize_request(
2088 &model.model,
2089 CompletionIntent::ThreadSummarization,
2090 added_user_message.into(),
2091 cx,
2092 );
2093
2094 self.summary = ThreadSummary::Generating;
2095
2096 self.pending_summary = cx.spawn(async move |this, cx| {
2097 let result = async {
2098 let mut messages = model.model.stream_completion(request, &cx).await?;
2099
2100 let mut new_summary = String::new();
2101 while let Some(event) = messages.next().await {
2102 let Ok(event) = event else {
2103 continue;
2104 };
2105 let text = match event {
2106 LanguageModelCompletionEvent::Text(text) => text,
2107 LanguageModelCompletionEvent::StatusUpdate(
2108 CompletionRequestStatus::UsageUpdated { amount, limit },
2109 ) => {
2110 this.update(cx, |thread, cx| {
2111 thread.update_model_request_usage(amount as u32, limit, cx);
2112 })?;
2113 continue;
2114 }
2115 _ => continue,
2116 };
2117
2118 let mut lines = text.lines();
2119 new_summary.extend(lines.next());
2120
2121 // Stop if the LLM generated multiple lines.
2122 if lines.next().is_some() {
2123 break;
2124 }
2125 }
2126
2127 anyhow::Ok(new_summary)
2128 }
2129 .await;
2130
2131 this.update(cx, |this, cx| {
2132 match result {
2133 Ok(new_summary) => {
2134 if new_summary.is_empty() {
2135 this.summary = ThreadSummary::Error;
2136 } else {
2137 this.summary = ThreadSummary::Ready(new_summary.into());
2138 }
2139 }
2140 Err(err) => {
2141 this.summary = ThreadSummary::Error;
2142 log::error!("Failed to generate thread summary: {}", err);
2143 }
2144 }
2145 cx.emit(ThreadEvent::SummaryGenerated);
2146 })
2147 .log_err()?;
2148
2149 Some(())
2150 });
2151 }
2152
2153 fn handle_rate_limit_error(
2154 &mut self,
2155 error: &LanguageModelCompletionError,
2156 retry_after: Duration,
2157 model: Arc<dyn LanguageModel>,
2158 intent: CompletionIntent,
2159 window: Option<AnyWindowHandle>,
2160 cx: &mut Context<Self>,
2161 ) {
2162 // For rate limit errors, we only retry once with the specified duration
2163 let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
2164 log::warn!(
2165 "Retrying completion request in {} seconds: {error:?}",
2166 retry_after.as_secs(),
2167 );
2168
2169 // Add a UI-only message instead of a regular message
2170 let id = self.next_message_id.post_inc();
2171 self.messages.push(Message {
2172 id,
2173 role: Role::System,
2174 segments: vec![MessageSegment::Text(retry_message)],
2175 loaded_context: LoadedContext::default(),
2176 creases: Vec::new(),
2177 is_hidden: false,
2178 ui_only: true,
2179 });
2180 cx.emit(ThreadEvent::MessageAdded(id));
2181 // Schedule the retry
2182 let thread_handle = cx.entity().downgrade();
2183
2184 cx.spawn(async move |_thread, cx| {
2185 cx.background_executor().timer(retry_after).await;
2186
2187 thread_handle
2188 .update(cx, |thread, cx| {
2189 // Retry the completion
2190 thread.send_to_model(model, intent, window, cx);
2191 })
2192 .log_err();
2193 })
2194 .detach();
2195 }
2196
    /// Attempts to schedule a retry for a retryable completion error using the
    /// default exponential-backoff delay.
    ///
    /// Thin wrapper around [`Self::handle_retryable_error_with_delay`] with no
    /// custom delay. Returns whether a retry was scheduled.
    fn handle_retryable_error(
        &mut self,
        error: &LanguageModelCompletionError,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
    }
2207
    /// Attempts to schedule a retry for a retryable completion error.
    ///
    /// Tracks attempts in `self.retry_state` (up to `MAX_RETRY_ATTEMPTS`).
    /// Uses `custom_delay` when given, otherwise exponential backoff based on
    /// `BASE_RETRY_DELAY_SECS`. Returns `true` when a retry was scheduled;
    /// on exhaustion, clears retry state, stops generation, emits
    /// `ThreadEvent::RetriesFailed`, and returns `false`.
    fn handle_retryable_error_with_delay(
        &mut self,
        error: &LanguageModelCompletionError,
        custom_delay: Option<Duration>,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        let retry_state = self.retry_state.get_or_insert(RetryState {
            attempt: 0,
            max_attempts: MAX_RETRY_ATTEMPTS,
            intent,
        });

        retry_state.attempt += 1;
        let attempt = retry_state.attempt;
        let max_attempts = retry_state.max_attempts;
        let intent = retry_state.intent;

        if attempt <= max_attempts {
            // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
            let delay = if let Some(custom_delay) = custom_delay {
                custom_delay
            } else {
                let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
                Duration::from_secs(delay_secs)
            };

            // Add a transient message to inform the user
            let delay_secs = delay.as_secs();
            let retry_message = format!(
                "{error}. Retrying (attempt {attempt} of {max_attempts}) \
                in {delay_secs} seconds..."
            );
            log::warn!(
                "Retrying completion request (attempt {attempt} of {max_attempts}) \
                in {delay_secs} seconds: {error:?}",
            );

            // Add a UI-only message instead of a regular message
            let id = self.next_message_id.post_inc();
            self.messages.push(Message {
                id,
                role: Role::System,
                segments: vec![MessageSegment::Text(retry_message)],
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
                is_hidden: false,
                ui_only: true,
            });
            cx.emit(ThreadEvent::MessageAdded(id));

            // Schedule the retry
            let thread_handle = cx.entity().downgrade();

            cx.spawn(async move |_thread, cx| {
                cx.background_executor().timer(delay).await;

                thread_handle
                    .update(cx, |thread, cx| {
                        // Retry the completion
                        thread.send_to_model(model, intent, window, cx);
                    })
                    .log_err();
            })
            .detach();

            true
        } else {
            // Max retries exceeded
            self.retry_state = None;

            let notification_text = if max_attempts == 1 {
                "Failed after retrying.".into()
            } else {
                format!("Failed after retrying {} times.", max_attempts).into()
            };

            // Stop generating since we're giving up on retrying.
            self.pending_completions.clear();

            cx.emit(ThreadEvent::RetriesFailed {
                message: notification_text,
            });

            false
        }
    }
2297
    /// Starts generating a detailed summary of the thread in the background,
    /// unless one is already generated (or generating) for the current last
    /// message.
    ///
    /// Does nothing when the thread has no messages, no summary model is
    /// configured, or the summary provider is not authenticated. On success
    /// the finished thread is saved via `thread_store` so the summary can be
    /// reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        // A summary keyed to the current last message is already available or
        // in flight — nothing to do.
        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Publish the "generating" state before spawning so observers see it
        // immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Stream failed to start: roll the state back so a later call
                // can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate streamed chunks; individual chunk errors are logged
            // and skipped rather than aborting the whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2387
    /// Waits until the detailed summary settles, then returns it.
    ///
    /// Loops while the state is `Generating`; returns the generated summary
    /// once available, or falls back to the thread's plain text when the
    /// state becomes `NotGenerated`. Returns `None` if the thread entity is
    /// dropped or the state channel closes.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
2405
2406 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2407 self.detailed_summary_rx
2408 .borrow()
2409 .text()
2410 .unwrap_or_else(|| self.text().into())
2411 }
2412
2413 pub fn is_generating_detailed_summary(&self) -> bool {
2414 matches!(
2415 &*self.detailed_summary_rx.borrow(),
2416 DetailedSummaryState::Generating { .. }
2417 )
2418 }
2419
2420 pub fn use_pending_tools(
2421 &mut self,
2422 window: Option<AnyWindowHandle>,
2423 model: Arc<dyn LanguageModel>,
2424 cx: &mut Context<Self>,
2425 ) -> Vec<PendingToolUse> {
2426 self.auto_capture_telemetry(cx);
2427 let request =
2428 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2429 let pending_tool_uses = self
2430 .tool_use
2431 .pending_tool_uses()
2432 .into_iter()
2433 .filter(|tool_use| tool_use.status.is_idle())
2434 .cloned()
2435 .collect::<Vec<_>>();
2436
2437 for tool_use in pending_tool_uses.iter() {
2438 self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2439 }
2440
2441 pending_tool_uses
2442 }
2443
2444 fn use_pending_tool(
2445 &mut self,
2446 tool_use: PendingToolUse,
2447 request: Arc<LanguageModelRequest>,
2448 model: Arc<dyn LanguageModel>,
2449 window: Option<AnyWindowHandle>,
2450 cx: &mut Context<Self>,
2451 ) {
2452 let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2453 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2454 };
2455
2456 if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2457 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2458 }
2459
2460 if tool.needs_confirmation(&tool_use.input, cx)
2461 && !AgentSettings::get_global(cx).always_allow_tool_actions
2462 {
2463 self.tool_use.confirm_tool_use(
2464 tool_use.id,
2465 tool_use.ui_text,
2466 tool_use.input,
2467 request,
2468 tool,
2469 );
2470 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2471 } else {
2472 self.run_tool(
2473 tool_use.id,
2474 tool_use.ui_text,
2475 tool_use.input,
2476 request,
2477 tool,
2478 model,
2479 window,
2480 cx,
2481 );
2482 }
2483 }
2484
2485 pub fn handle_hallucinated_tool_use(
2486 &mut self,
2487 tool_use_id: LanguageModelToolUseId,
2488 hallucinated_tool_name: Arc<str>,
2489 window: Option<AnyWindowHandle>,
2490 cx: &mut Context<Thread>,
2491 ) {
2492 let available_tools = self.profile.enabled_tools(cx);
2493
2494 let tool_list = available_tools
2495 .iter()
2496 .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2497 .collect::<Vec<_>>()
2498 .join("\n");
2499
2500 let error_message = format!(
2501 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2502 hallucinated_tool_name, tool_list
2503 );
2504
2505 let pending_tool_use = self.tool_use.insert_tool_output(
2506 tool_use_id.clone(),
2507 hallucinated_tool_name,
2508 Err(anyhow!("Missing tool call: {error_message}")),
2509 self.configured_model.as_ref(),
2510 );
2511
2512 cx.emit(ThreadEvent::MissingToolUse {
2513 tool_use_id: tool_use_id.clone(),
2514 ui_text: error_message.into(),
2515 });
2516
2517 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2518 }
2519
2520 pub fn receive_invalid_tool_json(
2521 &mut self,
2522 tool_use_id: LanguageModelToolUseId,
2523 tool_name: Arc<str>,
2524 invalid_json: Arc<str>,
2525 error: String,
2526 window: Option<AnyWindowHandle>,
2527 cx: &mut Context<Thread>,
2528 ) {
2529 log::error!("The model returned invalid input JSON: {invalid_json}");
2530
2531 let pending_tool_use = self.tool_use.insert_tool_output(
2532 tool_use_id.clone(),
2533 tool_name,
2534 Err(anyhow!("Error parsing input JSON: {error}")),
2535 self.configured_model.as_ref(),
2536 );
2537 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2538 pending_tool_use.ui_text.clone()
2539 } else {
2540 log::error!(
2541 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2542 );
2543 format!("Unknown tool {}", tool_use_id).into()
2544 };
2545
2546 cx.emit(ThreadEvent::InvalidToolInput {
2547 tool_use_id: tool_use_id.clone(),
2548 ui_text,
2549 invalid_input_json: invalid_json,
2550 });
2551
2552 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2553 }
2554
    /// Spawns execution of `tool` for the given tool use and marks the tool
    /// use as running with the provided UI text.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2571
    /// Starts the tool running and returns a task that resolves when the
    /// tool's output has been recorded on the thread.
    ///
    /// The tool's UI card (if any) is stored synchronously before the task is
    /// spawned so the UI can render it while the tool is still running.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // If the thread was dropped while the tool ran, the output is
                // silently discarded (`.ok()`).
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2618
2619 fn tool_finished(
2620 &mut self,
2621 tool_use_id: LanguageModelToolUseId,
2622 pending_tool_use: Option<PendingToolUse>,
2623 canceled: bool,
2624 window: Option<AnyWindowHandle>,
2625 cx: &mut Context<Self>,
2626 ) {
2627 if self.all_tools_finished() {
2628 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2629 if !canceled {
2630 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2631 }
2632 self.auto_capture_telemetry(cx);
2633 }
2634 }
2635
2636 cx.emit(ThreadEvent::ToolFinished {
2637 tool_use_id,
2638 pending_tool_use,
2639 });
2640 }
2641
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Also clears any retry state and cancels all pending tool uses,
    /// finalizing each as canceled.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // A scheduled retry counts as an in-flight completion for the purpose
        // of reporting a cancellation.
        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();

        self.retry_state = None;

        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }
2680
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It performs no work itself
    /// beyond emitting the event.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2688
    /// Returns the thread-level feedback, if any has been reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2692
    /// Returns the feedback reported for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2696
    /// Records feedback for a message and reports it via telemetry, along
    /// with a serialized copy of the thread and a project snapshot.
    ///
    /// Returns a completed task immediately if the same feedback was already
    /// recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything we need before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|(name, _)| name.clone().into())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Telemetry is best-effort: serialization failure degrades to null
            // rather than aborting the report.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2753
    /// Reports feedback for the thread as a whole.
    ///
    /// If the thread has an assistant message, the feedback is attached to
    /// the most recent one via [`Self::report_message_feedback`]; otherwise
    /// it is stored as thread-level feedback and reported directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Best-effort serialization: fall back to null on failure.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2799
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; unsaved (dirty) buffer
    /// paths are collected afterwards on the main context. Failures while
    /// reading buffers are ignored (`.ok()`), yielding a possibly empty list.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2837
    /// Builds a snapshot of a single worktree: its absolute path plus git
    /// metadata (remote URL, HEAD SHA, current branch, and a HEAD-to-worktree
    /// diff) when a local repository covering the worktree is found.
    ///
    /// Every failure path degrades gracefully: an unreadable worktree yields
    /// an empty path, and any git error yields `git_state: None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the first repository whose working directory contains this
            // worktree's root.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can
                            // query; remote ones report branch only.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task>> produced above, treating
            // any failure as "no git state".
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2915
2916 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2917 let mut markdown = Vec::new();
2918
2919 let summary = self.summary().or_default();
2920 writeln!(markdown, "# {summary}\n")?;
2921
2922 for message in self.messages() {
2923 writeln!(
2924 markdown,
2925 "## {role}\n",
2926 role = match message.role {
2927 Role::User => "User",
2928 Role::Assistant => "Agent",
2929 Role::System => "System",
2930 }
2931 )?;
2932
2933 if !message.loaded_context.text.is_empty() {
2934 writeln!(markdown, "{}", message.loaded_context.text)?;
2935 }
2936
2937 if !message.loaded_context.images.is_empty() {
2938 writeln!(
2939 markdown,
2940 "\n{} images attached as context.\n",
2941 message.loaded_context.images.len()
2942 )?;
2943 }
2944
2945 for segment in &message.segments {
2946 match segment {
2947 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2948 MessageSegment::Thinking { text, .. } => {
2949 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2950 }
2951 MessageSegment::RedactedThinking(_) => {}
2952 }
2953 }
2954
2955 for tool_use in self.tool_uses_for_message(message.id, cx) {
2956 writeln!(
2957 markdown,
2958 "**Use Tool: {} ({})**",
2959 tool_use.name, tool_use.id
2960 )?;
2961 writeln!(markdown, "```json")?;
2962 writeln!(
2963 markdown,
2964 "{}",
2965 serde_json::to_string_pretty(&tool_use.input)?
2966 )?;
2967 writeln!(markdown, "```")?;
2968 }
2969
2970 for tool_result in self.tool_results_for_message(message.id) {
2971 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2972 if tool_result.is_error {
2973 write!(markdown, " (Error)")?;
2974 }
2975
2976 writeln!(markdown, "**\n")?;
2977 match &tool_result.content {
2978 LanguageModelToolResultContent::Text(text) => {
2979 writeln!(markdown, "{text}")?;
2980 }
2981 LanguageModelToolResultContent::Image(image) => {
2982 writeln!(markdown, "", image.source)?;
2983 }
2984 }
2985
2986 if let Some(output) = tool_result.output.as_ref() {
2987 writeln!(
2988 markdown,
2989 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2990 serde_json::to_string_pretty(output)?
2991 )?;
2992 }
2993 }
2994 }
2995
2996 Ok(String::from_utf8_lossy(&markdown).to_string())
2997 }
2998
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept
    /// (accepted), delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
3009
    /// Marks all of the agent's outstanding edits as kept (accepted).
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
3014
    /// Reverts the agent's edits within the given ranges of `buffer`,
    /// delegating to the action log. Returns the task performing the revert.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
3025
    /// Returns the thread's action log entity.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
3029
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
3033
    /// Automatically captures the serialized thread via telemetry for users
    /// behind the `ThreadAutoCapture` feature flag.
    ///
    /// Throttled to at most one capture every 10 seconds; failures during
    /// serialization are silently dropped.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
            return;
        }

        // Rate-limit captures to once per 10 seconds.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client();
        let serialize_task = self.serialize(cx);

        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }
3077
    /// Returns the token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3081
3082 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3083 let Some(model) = self.configured_model.as_ref() else {
3084 return TotalTokenUsage::default();
3085 };
3086
3087 let max = model.model.max_token_count();
3088
3089 let index = self
3090 .messages
3091 .iter()
3092 .position(|msg| msg.id == message_id)
3093 .unwrap_or(0);
3094
3095 if index == 0 {
3096 return TotalTokenUsage { total: 0, max };
3097 }
3098
3099 let token_usage = &self
3100 .request_token_usage
3101 .get(index - 1)
3102 .cloned()
3103 .unwrap_or_default();
3104
3105 TotalTokenUsage {
3106 total: token_usage.total_tokens(),
3107 max,
3108 }
3109 }
3110
3111 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3112 let model = self.configured_model.as_ref()?;
3113
3114 let max = model.model.max_token_count();
3115
3116 if let Some(exceeded_error) = &self.exceeded_window_error {
3117 if model.model.id() == exceeded_error.model_id {
3118 return Some(TotalTokenUsage {
3119 total: exceeded_error.token_count,
3120 max,
3121 });
3122 }
3123 }
3124
3125 let total = self
3126 .token_usage_at_last_message()
3127 .unwrap_or_default()
3128 .total_tokens();
3129
3130 Some(TotalTokenUsage { total, max })
3131 }
3132
    /// Returns the request token usage recorded for the last message, falling
    /// back to the most recent recorded entry when the per-message record is
    /// shorter than the message list.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
3139
    /// Records `token_usage` for the last message, growing the per-message
    /// usage vector to match the message count (padding with the previous
    /// last-known usage) before overwriting the final slot.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
3149
    /// Propagates the latest model-request usage (amount used and limit) to
    /// the project's user store.
    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            project.user_store().update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            })
        });
    }
3163
3164 pub fn deny_tool_use(
3165 &mut self,
3166 tool_use_id: LanguageModelToolUseId,
3167 tool_name: Arc<str>,
3168 window: Option<AnyWindowHandle>,
3169 cx: &mut Context<Self>,
3170 ) {
3171 let err = Err(anyhow::anyhow!(
3172 "Permission to run tool action denied by user"
3173 ));
3174
3175 self.tool_use.insert_tool_output(
3176 tool_use_id.clone(),
3177 tool_name,
3178 err,
3179 self.configured_model.as_ref(),
3180 );
3181 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3182 }
3183}
3184
/// User-visible errors that can occur while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider requires payment before more requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has exhausted its model-request allowance.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and a detail message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3197
/// Events emitted by a [`Thread`] for the UI and other observers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error should be surfaced to the user.
    ShowError(ThreadError),
    /// A completion event was streamed from the model.
    StreamedCompletion,
    /// A chunk of text was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use is being streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model called a tool that doesn't exist or isn't enabled.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// A thread summary was generated.
    SummaryGenerated,
    /// The thread summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool use finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The set of checkpoints changed.
    CheckpointChanged,
    /// A tool use requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// The configured tool-use limit was reached.
    ToolUseLimitReached,
    /// In-progress editing should be canceled (see `Thread::cancel_editing`).
    CancelEditing,
    /// A pending completion was canceled.
    CompletionCanceled,
    /// The agent profile changed.
    ProfileChanged,
    /// All retry attempts for a failed completion were exhausted.
    RetriesFailed {
        message: SharedString,
    },
}
3245
// Allow `Thread` to emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
3247
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Monotonically increasing id identifying this completion.
    id: usize,
    // Queueing status of the request.
    queue_state: QueueState,
    // Keeps the driving task alive; dropping it cancels the completion.
    _task: Task<()>,
}
3253
3254#[cfg(test)]
3255mod tests {
3256 use super::*;
3257 use crate::{
3258 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3259 };
3260
    // Test-specific constants.
    // Presumably the retry delay (seconds) simulated for rate-limit tests —
    // usage is outside this view; confirm against the retry tests.
    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3263 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3264 use assistant_tool::ToolRegistry;
3265 use assistant_tools;
3266 use futures::StreamExt;
3267 use futures::future::BoxFuture;
3268 use futures::stream::BoxStream;
3269 use gpui::TestAppContext;
3270 use http_client;
3271 use indoc::indoc;
3272 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3273 use language_model::{
3274 LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3275 LanguageModelProviderName, LanguageModelToolChoice,
3276 };
3277 use parking_lot::Mutex;
3278 use project::{FakeFs, Project};
3279 use prompt_store::PromptBuilder;
3280 use serde_json::json;
3281 use settings::{Settings, SettingsStore};
3282 use std::sync::Arc;
3283 use std::time::Duration;
3284 use theme::ThemeSettings;
3285 use util::path;
3286 use workspace::Workspace;
3287
    /// Verifies that a user message inserted with loaded file context carries
    /// the rendered `<context>` block, and that the context text is prepended
    /// to the message content in the resulting completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Two messages: system prompt + the user message we inserted.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3364
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message: attach file1. Nothing has been sent yet, so
        // `new_context_for_thread` reports it as new.
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message: the store now holds file1 and file2, but only file2
        // is new — file1 was already attached to message 1.
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message: likewise, only the freshly-added file3 counts as new
        // (file1 and file2 were included with earlier messages).
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Verify which context each stored message actually carries.
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Rebuild the full completion request and confirm each message keeps
        // only the context that was new when it was inserted.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Now query relative to message 2: everything except what message 1
        // already included (file1) should be reported as new — here that's
        // file2, file3, and the just-added file4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        // Removing a context from the store shrinks the "new" set accordingly.
        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3520
3521 #[gpui::test]
3522 async fn test_message_without_files(cx: &mut TestAppContext) {
3523 init_test_settings(cx);
3524
3525 let project = create_test_project(
3526 cx,
3527 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3528 )
3529 .await;
3530
3531 let (_, _thread_store, thread, _context_store, model) =
3532 setup_test_environment(cx, project.clone()).await;
3533
3534 // Insert user message without any context (empty context vector)
3535 let message_id = thread.update(cx, |thread, cx| {
3536 thread.insert_user_message(
3537 "What is the best way to learn Rust?",
3538 ContextLoadResult::default(),
3539 None,
3540 Vec::new(),
3541 cx,
3542 )
3543 });
3544
3545 // Check content and context in message object
3546 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3547
3548 // Context should be empty when no files are included
3549 assert_eq!(message.role, Role::User);
3550 assert_eq!(message.segments.len(), 1);
3551 assert_eq!(
3552 message.segments[0],
3553 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3554 );
3555 assert_eq!(message.loaded_context.text, "");
3556
3557 // Check message in request
3558 let request = thread.update(cx, |thread, cx| {
3559 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3560 });
3561
3562 assert_eq!(request.messages.len(), 2);
3563 assert_eq!(
3564 request.messages[1].string_contents(),
3565 "What is the best way to learn Rust?"
3566 );
3567
3568 // Add second message, also without context
3569 let message2_id = thread.update(cx, |thread, cx| {
3570 thread.insert_user_message(
3571 "Are there any good books?",
3572 ContextLoadResult::default(),
3573 None,
3574 Vec::new(),
3575 cx,
3576 )
3577 });
3578
3579 let message2 =
3580 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3581 assert_eq!(message2.loaded_context.text, "");
3582
3583 // Check that both messages appear in the request
3584 let request = thread.update(cx, |thread, cx| {
3585 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3586 });
3587
3588 assert_eq!(request.messages.len(), 3);
3589 assert_eq!(
3590 request.messages[1].string_contents(),
3591 "What is the best way to learn Rust?"
3592 );
3593 assert_eq!(
3594 request.messages[2].string_contents(),
3595 "Are there any good books?"
3596 );
3597 }
3598
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context; this makes it a tracked buffer whose
        // staleness the thread can observe.
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Seed the conversation: one user message carrying the file context,
        // plus an assistant reply.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });

        // Before any edit, no stale-buffer notification should exist.
        let notification = thread.read_with(cx, |thread, _| {
            find_tool_use(thread, "project_notifications")
        });
        assert!(
            notification.is_none(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Edit the tracked buffer so its contents diverge from the snapshot
        // captured in the context.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message after the edit.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Flushing notifications should surface the stale-buffer warning as a
        // `project_notifications` tool use.
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        let Some(notification_result) = thread.read_with(cx, |thread, _cx| {
            find_tool_use(thread, "project_notifications")
        }) else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification_result.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        // The notification lists the files that changed since they were read.
        let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]

            These files have changed since the last read:
            - code.rs
            "};
        assert_eq!(notification_content, expected_content);
    }
3684
3685 fn find_tool_use(thread: &Thread, tool_name: &str) -> Option<LanguageModelToolResult> {
3686 thread
3687 .messages()
3688 .filter_map(|message| {
3689 thread
3690 .tool_results_for_message(message.id)
3691 .into_iter()
3692 .find(|result| result.tool_name == tool_name.into())
3693 })
3694 .next()
3695 .cloned()
3696 }
3697
3698 #[gpui::test]
3699 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3700 init_test_settings(cx);
3701
3702 let project = create_test_project(
3703 cx,
3704 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3705 )
3706 .await;
3707
3708 let (_workspace, thread_store, thread, _context_store, _model) =
3709 setup_test_environment(cx, project.clone()).await;
3710
3711 // Check that we are starting with the default profile
3712 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3713 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3714 assert_eq!(
3715 profile,
3716 AgentProfile::new(AgentProfileId::default(), tool_set)
3717 );
3718 }
3719
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Serializing a fresh thread should record the default profile id.
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Round-trip: deserializing the snapshot must restore an equivalent
        // profile on the rebuilt thread.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3762
3763 #[gpui::test]
3764 async fn test_temperature_setting(cx: &mut TestAppContext) {
3765 init_test_settings(cx);
3766
3767 let project = create_test_project(
3768 cx,
3769 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3770 )
3771 .await;
3772
3773 let (_workspace, _thread_store, thread, _context_store, model) =
3774 setup_test_environment(cx, project.clone()).await;
3775
3776 // Both model and provider
3777 cx.update(|cx| {
3778 AgentSettings::override_global(
3779 AgentSettings {
3780 model_parameters: vec![LanguageModelParameters {
3781 provider: Some(model.provider_id().0.to_string().into()),
3782 model: Some(model.id().0.clone()),
3783 temperature: Some(0.66),
3784 }],
3785 ..AgentSettings::get_global(cx).clone()
3786 },
3787 cx,
3788 );
3789 });
3790
3791 let request = thread.update(cx, |thread, cx| {
3792 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3793 });
3794 assert_eq!(request.temperature, Some(0.66));
3795
3796 // Only model
3797 cx.update(|cx| {
3798 AgentSettings::override_global(
3799 AgentSettings {
3800 model_parameters: vec![LanguageModelParameters {
3801 provider: None,
3802 model: Some(model.id().0.clone()),
3803 temperature: Some(0.66),
3804 }],
3805 ..AgentSettings::get_global(cx).clone()
3806 },
3807 cx,
3808 );
3809 });
3810
3811 let request = thread.update(cx, |thread, cx| {
3812 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3813 });
3814 assert_eq!(request.temperature, Some(0.66));
3815
3816 // Only provider
3817 cx.update(|cx| {
3818 AgentSettings::override_global(
3819 AgentSettings {
3820 model_parameters: vec![LanguageModelParameters {
3821 provider: Some(model.provider_id().0.to_string().into()),
3822 model: None,
3823 temperature: Some(0.66),
3824 }],
3825 ..AgentSettings::get_global(cx).clone()
3826 },
3827 cx,
3828 );
3829 });
3830
3831 let request = thread.update(cx, |thread, cx| {
3832 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3833 });
3834 assert_eq!(request.temperature, Some(0.66));
3835
3836 // Same model name, different provider
3837 cx.update(|cx| {
3838 AgentSettings::override_global(
3839 AgentSettings {
3840 model_parameters: vec![LanguageModelParameters {
3841 provider: Some("anthropic".into()),
3842 model: Some(model.id().0.clone()),
3843 temperature: Some(0.66),
3844 }],
3845 ..AgentSettings::get_global(cx).clone()
3846 },
3847 cx,
3848 );
3849 });
3850
3851 let request = thread.update(cx, |thread, cx| {
3852 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3853 });
3854 assert_eq!(request.temperature, None);
3855 }
3856
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be Pending, with or_default() falling back to
        // the placeholder summary.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should be ignored while Pending.
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message so the thread has enough content to summarize.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Manual set_summary is also ignored while Generating.
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then end the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // The streamed chunks should be concatenated into the Ready summary.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Once Ready, manual set_summary takes effect.
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Setting an empty summary keeps Ready but falls back to DEFAULT.
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3941
3942 #[gpui::test]
3943 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3944 init_test_settings(cx);
3945
3946 let project = create_test_project(cx, json!({})).await;
3947
3948 let (_, _thread_store, thread, _context_store, model) =
3949 setup_test_environment(cx, project.clone()).await;
3950
3951 test_summarize_error(&model, &thread, cx);
3952
3953 // Now we should be able to set a summary
3954 thread.update(cx, |thread, cx| {
3955 thread.set_summary("Brief Intro", cx);
3956 });
3957
3958 thread.read_with(cx, |thread, _| {
3959 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3960 assert_eq!(thread.summary().or_default(), "Brief Intro");
3961 });
3962 }
3963
3964 #[gpui::test]
3965 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3966 init_test_settings(cx);
3967
3968 let project = create_test_project(cx, json!({})).await;
3969
3970 let (_, _thread_store, thread, _context_store, model) =
3971 setup_test_environment(cx, project.clone()).await;
3972
3973 test_summarize_error(&model, &thread, cx);
3974
3975 // Sending another message should not trigger another summarize request
3976 thread.update(cx, |thread, cx| {
3977 thread.insert_user_message(
3978 "How are you?",
3979 ContextLoadResult::default(),
3980 None,
3981 vec![],
3982 cx,
3983 );
3984 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3985 });
3986
3987 let fake_model = model.as_fake();
3988 simulate_successful_response(&fake_model, cx);
3989
3990 thread.read_with(cx, |thread, _| {
3991 // State is still Error, not Generating
3992 assert!(matches!(thread.summary(), ThreadSummary::Error));
3993 });
3994
3995 // But the summarize request can be invoked manually
3996 thread.update(cx, |thread, cx| {
3997 thread.summarize(cx);
3998 });
3999
4000 thread.read_with(cx, |thread, _| {
4001 assert!(matches!(thread.summary(), ThreadSummary::Generating));
4002 });
4003
4004 cx.run_until_parked();
4005 fake_model.stream_last_completion_response("A successful summary");
4006 fake_model.end_last_completion_stream();
4007 cx.run_until_parked();
4008
4009 thread.read_with(cx, |thread, _| {
4010 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4011 assert_eq!(thread.summary().or_default(), "A successful summary");
4012 });
4013 }
4014
    /// Which failure `ErrorInjector` should surface from `stream_completion`.
    enum TestError {
        // Provider reports it is overloaded.
        Overloaded,
        // Provider reports an internal server error.
        InternalServerError,
    }
4020
    /// A `LanguageModel` wrapper that delegates all metadata queries to an
    /// inner `FakeLanguageModel` but fails every completion with `error_type`.
    struct ErrorInjector {
        // Delegate for id/name/provider/token-count queries.
        inner: Arc<FakeLanguageModel>,
        // The error every `stream_completion` call yields.
        error_type: TestError,
    }
4025
4026 impl ErrorInjector {
4027 fn new(error_type: TestError) -> Self {
4028 Self {
4029 inner: Arc::new(FakeLanguageModel::default()),
4030 error_type,
4031 }
4032 }
4033 }
4034
    impl LanguageModel for ErrorInjector {
        // All metadata queries are forwarded to the wrapped fake model, so the
        // injector is indistinguishable from it except when streaming.
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        /// Always fails: the returned stream yields exactly one item — the
        /// error selected by `self.error_type` — and then ends. The request
        /// itself is ignored.
        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
                    provider: self.provider_name(),
                    retry_after: None,
                },
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError {
                        provider: self.provider_name(),
                        message: "I'm a teapot orbiting the sun".to_string(),
                    }
                }
            };
            async move {
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        // Tests use this to drive the inner fake model directly.
        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
4117
4118 #[gpui::test]
4119 async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4120 init_test_settings(cx);
4121
4122 let project = create_test_project(cx, json!({})).await;
4123 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4124
4125 // Create model that returns overloaded error
4126 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4127
4128 // Insert a user message
4129 thread.update(cx, |thread, cx| {
4130 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4131 });
4132
4133 // Start completion
4134 thread.update(cx, |thread, cx| {
4135 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4136 });
4137
4138 cx.run_until_parked();
4139
4140 thread.read_with(cx, |thread, _| {
4141 assert!(thread.retry_state.is_some(), "Should have retry state");
4142 let retry_state = thread.retry_state.as_ref().unwrap();
4143 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4144 assert_eq!(
4145 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4146 "Should have default max attempts"
4147 );
4148 });
4149
4150 // Check that a retry message was added
4151 thread.read_with(cx, |thread, _| {
4152 let mut messages = thread.messages();
4153 assert!(
4154 messages.any(|msg| {
4155 msg.role == Role::System
4156 && msg.ui_only
4157 && msg.segments.iter().any(|seg| {
4158 if let MessageSegment::Text(text) = seg {
4159 text.contains("overloaded")
4160 && text
4161 .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4162 } else {
4163 false
4164 }
4165 })
4166 }),
4167 "Should have added a system retry message"
4168 );
4169 });
4170
4171 let retry_count = thread.update(cx, |thread, _| {
4172 thread
4173 .messages
4174 .iter()
4175 .filter(|m| {
4176 m.ui_only
4177 && m.segments.iter().any(|s| {
4178 if let MessageSegment::Text(text) = s {
4179 text.contains("Retrying") && text.contains("seconds")
4180 } else {
4181 false
4182 }
4183 })
4184 })
4185 .count()
4186 });
4187
4188 assert_eq!(retry_count, 1, "Should have one retry message");
4189 }
4190
    #[gpui::test]
    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Every completion from this model fails with an internal server error.
        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // The failure should have armed the retry state machine.
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have correct max attempts"
            );
        });

        // The ui-only system retry message should mention the provider name
        // ("Fake") along with the attempt count.
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("internal")
                                    && text.contains("Fake")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message with provider name"
            );
        });

        // Exactly one "Retrying ... seconds" message should exist so far.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4266
    #[gpui::test]
    async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Every attempt (initial + retries) fails with "server overloaded",
        // so the retry machinery runs to exhaustion.
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Count completion requests via NewRequest events so the total number
        // of attempts can be verified at the end.
        let completion_count = Arc::new(Mutex::new(0));
        let completion_count_clone = completion_count.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::NewRequest = event {
                    *completion_count_clone.lock() += 1;
                }
            })
        });

        // First attempt
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Should have scheduled first retry - count "Retrying ... seconds"
        // ui-only messages as a proxy for scheduled retries.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(retry_count, 1, "Should have scheduled first retry");

        // Check retry state
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
        });

        // First retry fires after the base delay.
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // Should have scheduled second retry - count retry messages
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(retry_count, 2, "Should have scheduled second retry");

        // Check retry state updated
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have correct max attempts"
            );
        });

        // Second retry fires after twice the base delay (exponential backoff).
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
        cx.run_until_parked();

        // Should have scheduled third retry
        // Count all retry messages now
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(
            retry_count, MAX_RETRY_ATTEMPTS as usize,
            "Should have scheduled third retry"
        );

        // Check retry state updated
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(
                retry_state.attempt, MAX_RETRY_ATTEMPTS,
                "Should be at max retry attempt"
            );
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have correct max attempts"
            );
        });

        // Third retry fires after four times the base delay.
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
        cx.run_until_parked();

        // No more retries should be scheduled after clock was advanced.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(
            retry_count, MAX_RETRY_ATTEMPTS as usize,
            "Should not exceed max retries"
        );

        // Final completion count should be initial + max retries
        assert_eq!(
            *completion_count.lock(),
            (MAX_RETRY_ATTEMPTS + 1) as usize,
            "Should have made initial + max retry attempts"
        );
    }
4438
    #[gpui::test]
    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
        // Exercises the retry budget end-to-end: a model that always fails
        // with an "overloaded" error should be retried MAX_RETRY_ATTEMPTS
        // times, emit RetriesFailed once the budget is exhausted, and then
        // clear retry_state.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track events
        let retries_failed = Arc::new(Mutex::new(false));
        let retries_failed_clone = retries_failed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::RetriesFailed { .. } = event {
                    *retries_failed_clone.lock() = true;
                }
            })
        });

        // Start initial completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Advance through all retries
        // Delays follow the exponential backoff used by the thread:
        // attempt 1 waits BASE_RETRY_DELAY_SECS, attempt i>1 waits
        // BASE_RETRY_DELAY_SECS * 2^(i-1).
        for i in 0..MAX_RETRY_ATTEMPTS {
            let delay = if i == 0 {
                BASE_RETRY_DELAY_SECS
            } else {
                BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
            };
            cx.executor().advance_clock(Duration::from_secs(delay));
            cx.run_until_parked();
        }

        // After the 3rd retry is scheduled, we need to wait for it to execute and fail
        // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
        let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
        cx.executor()
            .advance_clock(Duration::from_secs(final_delay));
        cx.run_until_parked();

        // Count the ui-only "Retrying in N seconds" notifications; the thread
        // inserts exactly one per scheduled retry.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        // After max retries, should emit RetriesFailed event
        assert_eq!(
            retry_count, MAX_RETRY_ATTEMPTS as usize,
            "Should have attempted max retries"
        );
        assert!(
            *retries_failed.lock(),
            "Should emit RetriesFailed event after max retries exceeded"
        );

        // Retry state should be cleared
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after max retries"
            );

            // Verify we have the expected number of retry messages
            let retry_messages = thread
                .messages
                .iter()
                .filter(|msg| msg.ui_only && msg.role == Role::System)
                .count();
            assert_eq!(
                retry_messages, MAX_RETRY_ATTEMPTS as usize,
                "Should have one retry message per attempt"
            );
        });
    }
4536
4537 #[gpui::test]
4538 async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
4539 init_test_settings(cx);
4540
4541 let project = create_test_project(cx, json!({})).await;
4542 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4543
4544 // We'll use a wrapper to switch behavior after first failure
4545 struct RetryTestModel {
4546 inner: Arc<FakeLanguageModel>,
4547 failed_once: Arc<Mutex<bool>>,
4548 }
4549
4550 impl LanguageModel for RetryTestModel {
4551 fn id(&self) -> LanguageModelId {
4552 self.inner.id()
4553 }
4554
4555 fn name(&self) -> LanguageModelName {
4556 self.inner.name()
4557 }
4558
4559 fn provider_id(&self) -> LanguageModelProviderId {
4560 self.inner.provider_id()
4561 }
4562
4563 fn provider_name(&self) -> LanguageModelProviderName {
4564 self.inner.provider_name()
4565 }
4566
4567 fn supports_tools(&self) -> bool {
4568 self.inner.supports_tools()
4569 }
4570
4571 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4572 self.inner.supports_tool_choice(choice)
4573 }
4574
4575 fn supports_images(&self) -> bool {
4576 self.inner.supports_images()
4577 }
4578
4579 fn telemetry_id(&self) -> String {
4580 self.inner.telemetry_id()
4581 }
4582
4583 fn max_token_count(&self) -> u64 {
4584 self.inner.max_token_count()
4585 }
4586
4587 fn count_tokens(
4588 &self,
4589 request: LanguageModelRequest,
4590 cx: &App,
4591 ) -> BoxFuture<'static, Result<u64>> {
4592 self.inner.count_tokens(request, cx)
4593 }
4594
4595 fn stream_completion(
4596 &self,
4597 request: LanguageModelRequest,
4598 cx: &AsyncApp,
4599 ) -> BoxFuture<
4600 'static,
4601 Result<
4602 BoxStream<
4603 'static,
4604 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4605 >,
4606 LanguageModelCompletionError,
4607 >,
4608 > {
4609 if !*self.failed_once.lock() {
4610 *self.failed_once.lock() = true;
4611 let provider = self.provider_name();
4612 // Return error on first attempt
4613 let stream = futures::stream::once(async move {
4614 Err(LanguageModelCompletionError::ServerOverloaded {
4615 provider,
4616 retry_after: None,
4617 })
4618 });
4619 async move { Ok(stream.boxed()) }.boxed()
4620 } else {
4621 // Succeed on retry
4622 self.inner.stream_completion(request, cx)
4623 }
4624 }
4625
4626 fn as_fake(&self) -> &FakeLanguageModel {
4627 &self.inner
4628 }
4629 }
4630
4631 let model = Arc::new(RetryTestModel {
4632 inner: Arc::new(FakeLanguageModel::default()),
4633 failed_once: Arc::new(Mutex::new(false)),
4634 });
4635
4636 // Insert a user message
4637 thread.update(cx, |thread, cx| {
4638 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4639 });
4640
4641 // Track message deletions
4642 // Track when retry completes successfully
4643 let retry_completed = Arc::new(Mutex::new(false));
4644 let retry_completed_clone = retry_completed.clone();
4645
4646 let _subscription = thread.update(cx, |_, cx| {
4647 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4648 if let ThreadEvent::StreamedCompletion = event {
4649 *retry_completed_clone.lock() = true;
4650 }
4651 })
4652 });
4653
4654 // Start completion
4655 thread.update(cx, |thread, cx| {
4656 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4657 });
4658 cx.run_until_parked();
4659
4660 // Get the retry message ID
4661 let retry_message_id = thread.read_with(cx, |thread, _| {
4662 thread
4663 .messages()
4664 .find(|msg| msg.role == Role::System && msg.ui_only)
4665 .map(|msg| msg.id)
4666 .expect("Should have a retry message")
4667 });
4668
4669 // Wait for retry
4670 cx.executor()
4671 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4672 cx.run_until_parked();
4673
4674 // Stream some successful content
4675 let fake_model = model.as_fake();
4676 // After the retry, there should be a new pending completion
4677 let pending = fake_model.pending_completions();
4678 assert!(
4679 !pending.is_empty(),
4680 "Should have a pending completion after retry"
4681 );
4682 fake_model.stream_completion_response(&pending[0], "Success!");
4683 fake_model.end_completion_stream(&pending[0]);
4684 cx.run_until_parked();
4685
4686 // Check that the retry completed successfully
4687 assert!(
4688 *retry_completed.lock(),
4689 "Retry should have completed successfully"
4690 );
4691
4692 // Retry message should still exist but be marked as ui_only
4693 thread.read_with(cx, |thread, _| {
4694 let retry_msg = thread
4695 .message(retry_message_id)
4696 .expect("Retry message should still exist");
4697 assert!(retry_msg.ui_only, "Retry message should be ui_only");
4698 assert_eq!(
4699 retry_msg.role,
4700 Role::System,
4701 "Retry message should have System role"
4702 );
4703 });
4704 }
4705
4706 #[gpui::test]
4707 async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4708 init_test_settings(cx);
4709
4710 let project = create_test_project(cx, json!({})).await;
4711 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4712
4713 // Create a model that fails once then succeeds
4714 struct FailOnceModel {
4715 inner: Arc<FakeLanguageModel>,
4716 failed_once: Arc<Mutex<bool>>,
4717 }
4718
4719 impl LanguageModel for FailOnceModel {
4720 fn id(&self) -> LanguageModelId {
4721 self.inner.id()
4722 }
4723
4724 fn name(&self) -> LanguageModelName {
4725 self.inner.name()
4726 }
4727
4728 fn provider_id(&self) -> LanguageModelProviderId {
4729 self.inner.provider_id()
4730 }
4731
4732 fn provider_name(&self) -> LanguageModelProviderName {
4733 self.inner.provider_name()
4734 }
4735
4736 fn supports_tools(&self) -> bool {
4737 self.inner.supports_tools()
4738 }
4739
4740 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4741 self.inner.supports_tool_choice(choice)
4742 }
4743
4744 fn supports_images(&self) -> bool {
4745 self.inner.supports_images()
4746 }
4747
4748 fn telemetry_id(&self) -> String {
4749 self.inner.telemetry_id()
4750 }
4751
4752 fn max_token_count(&self) -> u64 {
4753 self.inner.max_token_count()
4754 }
4755
4756 fn count_tokens(
4757 &self,
4758 request: LanguageModelRequest,
4759 cx: &App,
4760 ) -> BoxFuture<'static, Result<u64>> {
4761 self.inner.count_tokens(request, cx)
4762 }
4763
4764 fn stream_completion(
4765 &self,
4766 request: LanguageModelRequest,
4767 cx: &AsyncApp,
4768 ) -> BoxFuture<
4769 'static,
4770 Result<
4771 BoxStream<
4772 'static,
4773 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4774 >,
4775 LanguageModelCompletionError,
4776 >,
4777 > {
4778 if !*self.failed_once.lock() {
4779 *self.failed_once.lock() = true;
4780 let provider = self.provider_name();
4781 // Return error on first attempt
4782 let stream = futures::stream::once(async move {
4783 Err(LanguageModelCompletionError::ServerOverloaded {
4784 provider,
4785 retry_after: None,
4786 })
4787 });
4788 async move { Ok(stream.boxed()) }.boxed()
4789 } else {
4790 // Succeed on retry
4791 self.inner.stream_completion(request, cx)
4792 }
4793 }
4794 }
4795
4796 let fail_once_model = Arc::new(FailOnceModel {
4797 inner: Arc::new(FakeLanguageModel::default()),
4798 failed_once: Arc::new(Mutex::new(false)),
4799 });
4800
4801 // Insert a user message
4802 thread.update(cx, |thread, cx| {
4803 thread.insert_user_message(
4804 "Test message",
4805 ContextLoadResult::default(),
4806 None,
4807 vec![],
4808 cx,
4809 );
4810 });
4811
4812 // Start completion with fail-once model
4813 thread.update(cx, |thread, cx| {
4814 thread.send_to_model(
4815 fail_once_model.clone(),
4816 CompletionIntent::UserPrompt,
4817 None,
4818 cx,
4819 );
4820 });
4821
4822 cx.run_until_parked();
4823
4824 // Verify retry state exists after first failure
4825 thread.read_with(cx, |thread, _| {
4826 assert!(
4827 thread.retry_state.is_some(),
4828 "Should have retry state after failure"
4829 );
4830 });
4831
4832 // Wait for retry delay
4833 cx.executor()
4834 .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4835 cx.run_until_parked();
4836
4837 // The retry should now use our FailOnceModel which should succeed
4838 // We need to help the FakeLanguageModel complete the stream
4839 let inner_fake = fail_once_model.inner.clone();
4840
4841 // Wait a bit for the retry to start
4842 cx.run_until_parked();
4843
4844 // Check for pending completions and complete them
4845 if let Some(pending) = inner_fake.pending_completions().first() {
4846 inner_fake.stream_completion_response(pending, "Success!");
4847 inner_fake.end_completion_stream(pending);
4848 }
4849 cx.run_until_parked();
4850
4851 thread.read_with(cx, |thread, _| {
4852 assert!(
4853 thread.retry_state.is_none(),
4854 "Retry state should be cleared after successful completion"
4855 );
4856
4857 let has_assistant_message = thread
4858 .messages
4859 .iter()
4860 .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4861 assert!(
4862 has_assistant_message,
4863 "Should have an assistant message after successful retry"
4864 );
4865 });
4866 }
4867
    #[gpui::test]
    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
        // Rate-limit errors carrying an explicit retry_after are handled
        // differently from overload errors: a single retry is scheduled for
        // the server-provided delay, retry_state stays None, and the
        // notification message omits any "attempt N" counter.
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create a model that returns rate limit error with retry_after
        struct RateLimitModel {
            inner: Arc<FakeLanguageModel>,
        }

        impl LanguageModel for RateLimitModel {
            // Metadata and capability queries all delegate to the inner fake.
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            // Always fails with RateLimitExceeded carrying a fixed
            // retry_after of TEST_RATE_LIMIT_RETRY_SECS.
            fn stream_completion(
                &self,
                _request: LanguageModelRequest,
                _cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                let provider = self.provider_name();
                async move {
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::RateLimitExceeded {
                            provider,
                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
                        })
                    });
                    Ok(stream.boxed())
                }
                .boxed()
            }

            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RateLimitModel {
            inner: Arc::new(FakeLanguageModel::default()),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Count ui-only notifications mentioning the rate limit; exactly one
        // retry should have been scheduled.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("rate limit exceeded")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });
        assert_eq!(retry_count, 1, "Should have scheduled one retry");

        // Unlike overload errors, rate-limit retries do not use the
        // multi-attempt retry_state machinery.
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Rate limit errors should not set retry_state"
            );
        });

        // Verify we have one retry message
        thread.read_with(cx, |thread, _| {
            let retry_messages = thread
                .messages
                .iter()
                .filter(|msg| {
                    msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("rate limit exceeded")
                            } else {
                                false
                            }
                        })
                })
                .count();
            assert_eq!(
                retry_messages, 1,
                "Should have one rate limit retry message"
            );
        });

        // Check that retry message doesn't include attempt count
        thread.read_with(cx, |thread, _| {
            let retry_message = thread
                .messages
                .iter()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .expect("Should have a retry message");

            // Check that the message doesn't contain attempt count
            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
                assert!(
                    !text.contains("attempt"),
                    "Rate limit retry message should not contain attempt count"
                );
                assert!(
                    text.contains(&format!(
                        "Retrying in {} seconds",
                        TEST_RATE_LIMIT_RETRY_SECS
                    )),
                    "Rate limit retry message should contain retry delay"
                );
            }
        });
    }
5044
5045 #[gpui::test]
5046 async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5047 init_test_settings(cx);
5048
5049 let project = create_test_project(cx, json!({})).await;
5050 let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5051
5052 // Insert a regular user message
5053 thread.update(cx, |thread, cx| {
5054 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5055 });
5056
5057 // Insert a UI-only message (like our retry notifications)
5058 thread.update(cx, |thread, cx| {
5059 let id = thread.next_message_id.post_inc();
5060 thread.messages.push(Message {
5061 id,
5062 role: Role::System,
5063 segments: vec![MessageSegment::Text(
5064 "This is a UI-only message that should not be sent to the model".to_string(),
5065 )],
5066 loaded_context: LoadedContext::default(),
5067 creases: Vec::new(),
5068 is_hidden: true,
5069 ui_only: true,
5070 });
5071 cx.emit(ThreadEvent::MessageAdded(id));
5072 });
5073
5074 // Insert another regular message
5075 thread.update(cx, |thread, cx| {
5076 thread.insert_user_message(
5077 "How are you?",
5078 ContextLoadResult::default(),
5079 None,
5080 vec![],
5081 cx,
5082 );
5083 });
5084
5085 // Generate the completion request
5086 let request = thread.update(cx, |thread, cx| {
5087 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5088 });
5089
5090 // Verify that the request only contains non-UI-only messages
5091 // Should have system prompt + 2 user messages, but not the UI-only message
5092 let user_messages: Vec<_> = request
5093 .messages
5094 .iter()
5095 .filter(|msg| msg.role == Role::User)
5096 .collect();
5097 assert_eq!(
5098 user_messages.len(),
5099 2,
5100 "Should have exactly 2 user messages"
5101 );
5102
5103 // Verify the UI-only content is not present anywhere in the request
5104 let request_text = request
5105 .messages
5106 .iter()
5107 .flat_map(|msg| &msg.content)
5108 .filter_map(|content| match content {
5109 MessageContent::Text(text) => Some(text.as_str()),
5110 _ => None,
5111 })
5112 .collect::<String>();
5113
5114 assert!(
5115 !request_text.contains("UI-only message"),
5116 "UI-only message content should not be in the request"
5117 );
5118
5119 // Verify the thread still has all 3 messages (including UI-only)
5120 thread.read_with(cx, |thread, _| {
5121 assert_eq!(
5122 thread.messages().count(),
5123 3,
5124 "Thread should have 3 messages"
5125 );
5126 assert_eq!(
5127 thread.messages().filter(|m| m.ui_only).count(),
5128 1,
5129 "Thread should have 1 UI-only message"
5130 );
5131 });
5132
5133 // Verify that UI-only messages are not serialized
5134 let serialized = thread
5135 .update(cx, |thread, cx| thread.serialize(cx))
5136 .await
5137 .unwrap();
5138 assert_eq!(
5139 serialized.messages.len(),
5140 2,
5141 "Serialized thread should only have 2 messages (no UI-only)"
5142 );
5143 }
5144
5145 #[gpui::test]
5146 async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5147 init_test_settings(cx);
5148
5149 let project = create_test_project(cx, json!({})).await;
5150 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5151
5152 // Create model that returns overloaded error
5153 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5154
5155 // Insert a user message
5156 thread.update(cx, |thread, cx| {
5157 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5158 });
5159
5160 // Start completion
5161 thread.update(cx, |thread, cx| {
5162 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5163 });
5164
5165 cx.run_until_parked();
5166
5167 // Verify retry was scheduled by checking for retry message
5168 let has_retry_message = thread.read_with(cx, |thread, _| {
5169 thread.messages.iter().any(|m| {
5170 m.ui_only
5171 && m.segments.iter().any(|s| {
5172 if let MessageSegment::Text(text) = s {
5173 text.contains("Retrying") && text.contains("seconds")
5174 } else {
5175 false
5176 }
5177 })
5178 })
5179 });
5180 assert!(has_retry_message, "Should have scheduled a retry");
5181
5182 // Cancel the completion before the retry happens
5183 thread.update(cx, |thread, cx| {
5184 thread.cancel_last_completion(None, cx);
5185 });
5186
5187 cx.run_until_parked();
5188
5189 // The retry should not have happened - no pending completions
5190 let fake_model = model.as_fake();
5191 assert_eq!(
5192 fake_model.pending_completions().len(),
5193 0,
5194 "Should have no pending completions after cancellation"
5195 );
5196
5197 // Verify the retry was cancelled by checking retry state
5198 thread.read_with(cx, |thread, _| {
5199 if let Some(retry_state) = &thread.retry_state {
5200 panic!(
5201 "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5202 retry_state.attempt, retry_state.max_attempts, retry_state.intent
5203 );
5204 }
5205 });
5206 }
5207
    /// Shared helper: drives a summarization request through `model` and
    /// asserts that the thread summary moves from `Generating` to `Error`
    /// (falling back to the default title) when the summary stream ends
    /// without producing usable text.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // While the summary request is still in flight the state is
        // Generating, and or_default() yields the placeholder title.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
5242
    /// Settles pending tasks, then completes the model's most recent pending
    /// completion with a canned assistant response and settles again, so the
    /// caller observes a fully finished successful turn.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
5249
    /// Registers the global settings and subsystems these tests depend on.
    /// Must run before any test creates a project or thread. The call order
    /// mirrors production initialization; do not reorder casually.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
            assistant_tool::init(cx);

            // Tool initialization needs an HTTP client; use a stub that
            // always answers 200 so nothing touches the network.
            let http_client = Arc::new(http_client::HttpClientWithUrl::new(
                http_client::FakeHttpClient::with_200_response(),
                "http://localhost".to_string(),
                None,
            ));
            assistant_tools::init(http_client, cx);
        });
    }
5273
5274 // Helper to create a test project with test files
5275 async fn create_test_project(
5276 cx: &mut TestAppContext,
5277 files: serde_json::Value,
5278 ) -> Entity<Project> {
5279 let fs = FakeFs::new(cx.executor());
5280 fs.insert_tree(path!("/test"), files).await;
5281 Project::test(fs, [path!("/test").as_ref()], cx).await
5282 }
5283
    /// Boots a test workspace, thread store, thread, and context store for
    /// `project`, and installs a `FakeLanguageModel` as both the default and
    /// the thread-summary model in the global registry.
    ///
    /// Returns `(workspace, thread_store, thread, context_store, model)`.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model for both regular completions and thread
        // summarization so no test accidentally reaches a real provider.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
5338
5339 async fn add_file_to_context(
5340 project: &Entity<Project>,
5341 context_store: &Entity<ContextStore>,
5342 path: &str,
5343 cx: &mut TestAppContext,
5344 ) -> Result<Entity<language::Buffer>> {
5345 let buffer_path = project
5346 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5347 .unwrap();
5348
5349 let buffer = project
5350 .update(cx, |project, cx| {
5351 project.open_buffer(buffer_path.clone(), cx)
5352 })
5353 .await
5354 .unwrap();
5355
5356 context_store.update(cx, |context_store, cx| {
5357 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5358 });
5359
5360 Ok(buffer)
5361 }
5362}