1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9 tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
10};
11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
12use anyhow::{Result, anyhow};
13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
14use chrono::{DateTime, Utc};
15use client::{ModelRequestUsage, RequestUsage};
16use collections::HashMap;
17use feature_flags::{self, FeatureFlagAppExt};
18use futures::{FutureExt, StreamExt as _, future::Shared};
19use git::repository::DiffType;
20use gpui::{
21 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
22 WeakEntity, Window,
23};
24use http_client::StatusCode;
25use language_model::{
26 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
27 LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
28 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
29 LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
30 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
31 TokenUsage,
32};
33use postage::stream::Stream as _;
34use project::{
35 Project,
36 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
37};
38use prompt_store::{ModelContext, PromptBuilder};
39use proto::Plan;
40use schemars::JsonSchema;
41use serde::{Deserialize, Serialize};
42use settings::Settings;
43use std::{
44 io::Write,
45 ops::Range,
46 sync::Arc,
47 time::{Duration, Instant},
48};
49use thiserror::Error;
50use util::{ResultExt as _, debug_panic, post_inc};
51use uuid::Uuid;
52use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
53
/// Maximum number of retry attempts for a failed completion request.
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay between completion retry attempts (see [`RetryStrategy`]).
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
56
/// How a failed completion request should be retried.
#[derive(Debug, Clone)]
enum RetryStrategy {
    /// Retry with a delay that grows per attempt, starting at `initial_delay`.
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    /// Retry with the same `delay` between every attempt.
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}
68
/// Globally unique identifier of a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Generates a fresh random (v4 UUID) thread ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
91
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Generates a fresh random (v4 UUID) prompt ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
109
/// Identifier of a [`Message`] within a single [`Thread`], assigned sequentially.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    /// Raw index value backing this ID.
    pub fn as_usize(&self) -> usize {
        self.0
    }
}
122
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text that the crease covers.
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
132
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context attached to this message (cleared of handles on deserialization).
    pub loaded_context: LoadedContext,
    /// Creases for restoring context chips when re-editing this message.
    pub creases: Vec<MessageCrease>,
    /// Hidden messages are not treated as regular user turns
    /// (e.g. the auto-inserted "Continue where you left off" message).
    pub is_hidden: bool,
    /// UI-only messages are skipped when serializing the thread.
    pub ui_only: bool,
}
144
145impl Message {
146 /// Returns whether the message contains any meaningful text that should be displayed
147 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
148 pub fn should_display_content(&self) -> bool {
149 self.segments.iter().all(|segment| segment.should_display())
150 }
151
152 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
153 if let Some(MessageSegment::Thinking {
154 text: segment,
155 signature: current_signature,
156 }) = self.segments.last_mut()
157 {
158 if let Some(signature) = signature {
159 *current_signature = Some(signature);
160 }
161 segment.push_str(text);
162 } else {
163 self.segments.push(MessageSegment::Thinking {
164 text: text.to_string(),
165 signature,
166 });
167 }
168 }
169
170 pub fn push_redacted_thinking(&mut self, data: String) {
171 self.segments.push(MessageSegment::RedactedThinking(data));
172 }
173
174 pub fn push_text(&mut self, text: &str) {
175 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
176 segment.push_str(text);
177 } else {
178 self.segments.push(MessageSegment::Text(text.to_string()));
179 }
180 }
181
182 pub fn to_string(&self) -> String {
183 let mut result = String::new();
184
185 if !self.loaded_context.text.is_empty() {
186 result.push_str(&self.loaded_context.text);
187 }
188
189 for segment in &self.segments {
190 match segment {
191 MessageSegment::Text(text) => result.push_str(text),
192 MessageSegment::Thinking { text, .. } => {
193 result.push_str("<think>\n");
194 result.push_str(text);
195 result.push_str("\n</think>");
196 }
197 MessageSegment::RedactedThinking(_) => {}
198 }
199 }
200
201 result
202 }
203}
204
/// One piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary response text.
    Text(String),
    /// Visible chain-of-thought text, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-encrypted thinking payload; opaque and never displayed.
    RedactedThinking(String),
}

impl MessageSegment {
    /// Whether this segment holds content worth displaying.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is never displayed. (The previous code returned
    /// `text.is_empty()`, inverting the documented meaning.)
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }

    /// The plain text of a `Text` segment, or `None` for other variants.
    pub fn text(&self) -> Option<&str> {
        match self {
            MessageSegment::Text(text) => Some(text),
            _ => None,
        }
    }
}
231
/// Point-in-time capture of the project's worktrees, taken when a thread starts
/// and stored alongside the serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree has no associated git repository.
    pub git_state: Option<GitState>,
}

/// Git metadata captured alongside a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff at capture time, if one was collected.
    pub diff: Option<String>,
}
252
/// Associates a git checkpoint with the message it was captured for,
/// enabling "restore to this point" on past messages.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up/down feedback a user can leave on a thread or a single message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
264
/// State of the most recent checkpoint-restore attempt.
/// Cleared (set back to `None` on the thread) once a restore succeeds.
pub enum LastRestoreCheckpoint {
    /// The restore is still running.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed; `error` holds the display message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
274
275impl LastRestoreCheckpoint {
276 pub fn message_id(&self) -> MessageId {
277 match self {
278 LastRestoreCheckpoint::Pending { message_id } => *message_id,
279 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
280 }
281 }
282}
283
/// Progress of the on-demand detailed summary for a thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// A summary is currently being generated.
    Generating {
        message_id: MessageId,
    },
    /// A summary is available.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
296
297impl DetailedSummaryState {
298 fn text(&self) -> Option<SharedString> {
299 if let Self::Generated { text, .. } = self {
300 Some(text.clone())
301 } else {
302 None
303 }
304 }
305}
306
/// Running token total for a thread, measured against the active model's
/// context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Context-window size; 0 when no model is selected.
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, nearing the limit, or over it.
    ///
    /// In debug builds the warning threshold can be overridden via the
    /// `ZED_THREAD_WARNING_THRESHOLD` env var. A missing or malformed value
    /// falls back to the default instead of panicking (the previous code
    /// called `.unwrap()` on the parse of a user-supplied string).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more counted against the same maximum.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close a thread's token usage is to the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
351
/// Progress of a pending completion request before events start streaming.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is queued at the given position.
    Queued { position: usize },
    /// The completion has started streaming.
    Started,
}
358
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Short, user-visible title for the thread.
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    // Watch channel broadcasting detailed-summary progress to observers.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Kept sorted by `MessageId`; `Thread::message` binary-searches on it.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints captured per message, used for restore/rollback.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured at the latest user message; finalized (kept or
    // discarded) once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage, in request order.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; drives staleness detection.
    last_received_chunk_at: Option<Instant>,
    // Hook observing each request and its streamed events (set via
    // `set_request_callback`).
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Remaining turns the agent may take; `u32::MAX` by default.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
400
/// Bookkeeping for an in-progress retry sequence.
#[derive(Clone, Debug)]
struct RetryState {
    /// Attempts made so far in this sequence.
    attempt: u8,
    max_attempts: u8,
    /// Intent of the request being retried, so it can be re-issued faithfully.
    intent: CompletionIntent,
}

/// Lifecycle of a thread's auto-generated title.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// Summary generation has not started yet.
    Pending,
    /// A summary is being generated.
    Generating,
    /// The generated summary text.
    Ready(SharedString),
    /// Summary generation failed.
    Error,
}
415
416impl ThreadSummary {
417 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
418
419 pub fn or_default(&self) -> SharedString {
420 self.unwrap_or(Self::DEFAULT)
421 }
422
423 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
424 self.ready().unwrap_or_else(|| message.into())
425 }
426
427 pub fn ready(&self) -> Option<SharedString> {
428 match self {
429 ThreadSummary::Ready(summary) => Some(summary.clone()),
430 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
431 }
432 }
433}
434
/// Recorded when a request exceeded the selected model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
442
443impl Thread {
    /// Creates an empty thread backed by `project`, using the registry's
    /// default model and the globally configured default agent profile.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project's worktree/git state once, in the
                // background, so it can be attached when serializing.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
499
    /// Reconstructs a thread from its serialized form.
    ///
    /// Context handles and UI-only messages are not persisted, so they are not
    /// restored here; the model falls back to the registry default when the
    /// saved model is unavailable.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; otherwise the default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // live handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
623
    /// Installs a hook that is invoked with every completion request and the
    /// full list of events it streamed back.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
631
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches the agent profile, notifying observers only on an actual change.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the last-modified timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new user prompt; subsequent requests are attributed to the new ID.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
666
    /// The selected model, lazily initialized from the registry default the
    /// first time it is needed.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
686
687 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
688 let current_summary = match &self.summary {
689 ThreadSummary::Pending | ThreadSummary::Generating => return,
690 ThreadSummary::Ready(summary) => summary,
691 ThreadSummary::Error => &ThreadSummary::DEFAULT,
692 };
693
694 let mut new_summary = new_summary.into();
695
696 if new_summary.is_empty() {
697 new_summary = ThreadSummary::DEFAULT;
698 }
699
700 if current_summary != &new_summary {
701 self.summary = ThreadSummary::Ready(new_summary);
702 cx.emit(ThreadEvent::SummaryChanged);
703 }
704 }
705
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by ID.
    ///
    /// Relies on `messages` being sorted by ID, which holds because
    /// `insert_message` always appends with an increasing `next_message_id`.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
730
    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    ///
    /// NOTE(review): the `None`-when-not-generating claim relies on
    /// `last_received_chunk_at` being cleared when generation ends, which is
    /// not visible here — confirm against the completion-handling code.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // More than 250ms since the last streamed chunk counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
749
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given ID, if it is still in flight.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
775
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-progress so observers can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so it can be surfaced next to the message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
811
812 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
813 let pending_checkpoint = if self.is_generating() {
814 return;
815 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
816 checkpoint
817 } else {
818 return;
819 };
820
821 self.finalize_checkpoint(pending_checkpoint, cx);
822 }
823
    /// Compares `pending_checkpoint` against the repository's current state and
    /// keeps it only if the two differ, i.e. the turn actually changed something.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Treat a failed comparison as "not equal" so the checkpoint
                // is conservatively retained.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If the current state can't be captured, keep the checkpoint
            // rather than risk losing a restore point.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
858
    /// Associates `checkpoint` with its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
869
870 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
871 let Some(message_ix) = self
872 .messages
873 .iter()
874 .rposition(|message| message.id == message_id)
875 else {
876 return;
877 };
878 for deleted_message in self.messages.drain(message_ix..) {
879 self.checkpoints_by_message.remove(&deleted_message.id);
880 }
881 cx.notify();
882 }
883
884 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
885 self.messages
886 .iter()
887 .find(|message| message.id == id)
888 .into_iter()
889 .flat_map(|message| message.loaded_context.contexts.iter())
890 }
891
892 pub fn is_turn_end(&self, ix: usize) -> bool {
893 if self.messages.is_empty() {
894 return false;
895 }
896
897 if !self.is_generating() && ix == self.messages.len() - 1 {
898 return true;
899 }
900
901 let Some(message) = self.messages.get(ix) else {
902 return false;
903 };
904
905 if message.role != Role::Assistant {
906 return false;
907 }
908
909 self.messages
910 .get(ix + 1)
911 .and_then(|message| {
912 self.message(message.id)
913 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
914 })
915 .unwrap_or(false)
916 }
917
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }

    /// Tool uses issued by the assistant message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
944
    /// Tool results produced in response to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a finished tool use; `None` when the tool returned an image.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
969
970 /// Return tools that are both enabled and supported by the model
971 pub fn available_tools(
972 &self,
973 cx: &App,
974 model: Arc<dyn LanguageModel>,
975 ) -> Vec<LanguageModelRequestTool> {
976 if model.supports_tools() {
977 self.profile
978 .enabled_tools(cx)
979 .into_iter()
980 .filter_map(|(name, tool)| {
981 // Skip tools that cannot be supported
982 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
983 Some(LanguageModelRequestTool {
984 name: name.into(),
985 description: tool.description(),
986 input_schema,
987 })
988 })
989 .collect()
990 } else {
991 Vec::default()
992 }
993 }
994
    /// Appends a user message, registering any buffers the context references
    /// with the action log and remembering `git_checkpoint` for finalization
    /// after the turn completes.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Record reads of buffers referenced by the attached context.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        // Hold on to the checkpoint; it is only kept permanently if the turn
        // ends up changing the repository (see `finalize_checkpoint`).
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
1031
1032 pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
1033 let id = self.insert_message(
1034 Role::User,
1035 vec![MessageSegment::Text("Continue where you left off".into())],
1036 LoadedContext::default(),
1037 vec![],
1038 true,
1039 cx,
1040 );
1041 self.pending_checkpoint = None;
1042
1043 id
1044 }
1045
    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1060
    /// Appends a message with the next sequential ID and emits `MessageAdded`.
    ///
    /// Appending with `next_message_id.post_inc()` keeps `messages` sorted by
    /// ID, which `message()` relies on for binary search.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
            ui_only: false,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1084
    /// Rewrites an existing message in place; returns `false` if `id` is unknown.
    ///
    /// `loaded_context` and `checkpoint` are applied only when `Some`, so a
    /// caller can edit the text without discarding existing context or
    /// checkpoints.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1117
1118 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1119 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1120 return false;
1121 };
1122 self.messages.remove(index);
1123 self.touch_updated_at();
1124 cx.emit(ThreadEvent::MessageDeleted(id));
1125 true
1126 }
1127
1128 /// Returns the representation of this [`Thread`] in a textual form.
1129 ///
1130 /// This is the representation we use when attaching a thread as context to another thread.
1131 pub fn text(&self) -> String {
1132 let mut text = String::new();
1133
1134 for message in &self.messages {
1135 text.push_str(match message.role {
1136 language_model::Role::User => "User:",
1137 language_model::Role::Assistant => "Agent:",
1138 language_model::Role::System => "System:",
1139 });
1140 text.push('\n');
1141
1142 for segment in &message.segments {
1143 match segment {
1144 MessageSegment::Text(content) => text.push_str(content),
1145 MessageSegment::Thinking { text: content, .. } => {
1146 text.push_str(&format!("<think>{}</think>", content))
1147 }
1148 MessageSegment::RedactedThinking(_) => {}
1149 }
1150 }
1151 text.push('\n');
1152 }
1153
1154 text
1155 }
1156
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Returns a task that resolves once the initial project snapshot is
    /// available; the snapshot is awaited off the main read so serialization
    /// doesn't block on it synchronously.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Clone the shared future so the spawned task can await it without
        // borrowing `self`.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                // UI-only messages (e.g. transient retry notices) are never
                // persisted.
                messages: this
                    .messages()
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1243
    /// Returns how many more turns this thread may send to the model before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1247
    /// Sets the number of turns this thread may still send to the model.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1251
    /// Sends this thread's conversation to `model`, consuming one turn from
    /// the remaining-turn budget.
    ///
    /// Does nothing once `remaining_turns` has reached zero. Flushes any
    /// pending auto-generated notifications first, then builds the completion
    /// request from the current thread state and streams the response.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        // Stop automatically once the turn budget is exhausted.
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        self.flush_notifications(model.clone(), intent, cx);

        let request = self.to_completion_request(model.clone(), intent, cx);

        self.stream_completion(request, model, intent, window, cx);
    }
1271
1272 pub fn used_tools_since_last_user_message(&self) -> bool {
1273 for message in self.messages.iter().rev() {
1274 if self.tool_use.message_has_tool_results(message.id) {
1275 return true;
1276 } else if message.role == Role::User {
1277 return false;
1278 }
1279 }
1280
1281 false
1282 }
1283
    /// Builds the full [`LanguageModelRequest`] for this thread: system
    /// prompt, conversation messages, completed tool uses/results, and the
    /// set of available tools.
    ///
    /// If the system prompt cannot be generated (project context missing or
    /// prompt-builder failure), a `ThreadEvent::ShowError` is emitted and the
    /// request is still returned, just without a system message.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt needs the project context, which loads
        // asynchronously; if it isn't ready yet we surface an error rather
        // than silently sending a promptless request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the newest message eligible for a prompt-cache breakpoint;
        // only the single latest eligible message gets `cache = true`.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Trailing whitespace is stripped; fully-empty text
                        // segments are dropped entirely.
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are attached to this message; their results
            // are sent back in a synthetic follow-up User message.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use means this message may still change,
                    // so it must not become a cache breakpoint.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            // Models without burn-mode support always run in normal mode.
            Some(CompletionMode::Normal.into())
        };

        request
    }
1447
1448 fn to_summarize_request(
1449 &self,
1450 model: &Arc<dyn LanguageModel>,
1451 intent: CompletionIntent,
1452 added_user_message: String,
1453 cx: &App,
1454 ) -> LanguageModelRequest {
1455 let mut request = LanguageModelRequest {
1456 thread_id: None,
1457 prompt_id: None,
1458 intent: Some(intent),
1459 mode: None,
1460 messages: vec![],
1461 tools: Vec::new(),
1462 tool_choice: None,
1463 stop: Vec::new(),
1464 temperature: AgentSettings::temperature_for_model(model, cx),
1465 };
1466
1467 for message in &self.messages {
1468 let mut request_message = LanguageModelRequestMessage {
1469 role: message.role,
1470 content: Vec::new(),
1471 cache: false,
1472 };
1473
1474 for segment in &message.segments {
1475 match segment {
1476 MessageSegment::Text(text) => request_message
1477 .content
1478 .push(MessageContent::Text(text.clone())),
1479 MessageSegment::Thinking { .. } => {}
1480 MessageSegment::RedactedThinking(_) => {}
1481 }
1482 }
1483
1484 if request_message.content.is_empty() {
1485 continue;
1486 }
1487
1488 request.messages.push(request_message);
1489 }
1490
1491 request.messages.push(LanguageModelRequestMessage {
1492 role: Role::User,
1493 content: vec![MessageContent::Text(added_user_message)],
1494 cache: false,
1495 });
1496
1497 request
1498 }
1499
1500 /// Insert auto-generated notifications (if any) to the thread
1501 fn flush_notifications(
1502 &mut self,
1503 model: Arc<dyn LanguageModel>,
1504 intent: CompletionIntent,
1505 cx: &mut Context<Self>,
1506 ) {
1507 match intent {
1508 CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
1509 if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
1510 cx.emit(ThreadEvent::ToolFinished {
1511 tool_use_id: pending_tool_use.id.clone(),
1512 pending_tool_use: Some(pending_tool_use),
1513 });
1514 }
1515 }
1516 CompletionIntent::ThreadSummarization
1517 | CompletionIntent::ThreadContextSummarization
1518 | CompletionIntent::CreateFile
1519 | CompletionIntent::EditFile
1520 | CompletionIntent::InlineAssist
1521 | CompletionIntent::TerminalInlineAssist
1522 | CompletionIntent::GenerateGitCommitMessage => {}
1523 };
1524 }
1525
    /// Checks whether any tracked buffers have gone stale without the model
    /// being notified, and if so injects a simulated `project_notifications`
    /// tool call (and its output) into the most recent Assistant message.
    ///
    /// Returns the resulting pending tool use, or `None` when there is
    /// nothing to notify, the tool is missing or disabled for the current
    /// profile, or no Assistant message exists yet to attach to.
    fn attach_tracked_files_state(
        &mut self,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Option<PendingToolUse> {
        let action_log = self.action_log.read(cx);

        // Early out (via `?`) when there are no unnotified stale buffers.
        action_log.unnotified_stale_buffers(cx).next()?;

        // Represent notification as a simulated `project_notifications` tool call
        let tool_name = Arc::from("project_notifications");
        let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
            debug_panic!("`project_notifications` tool not found");
            return None;
        };

        // Respect the user's profile configuration: never run a disabled tool.
        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return None;
        }

        let input = serde_json::json!({});
        let request = Arc::new(LanguageModelRequest::default()); // unused
        let window = None;
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model.clone(),
            window,
            cx,
        );

        let tool_use_id =
            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));

        let tool_use = LanguageModelToolUse {
            id: tool_use_id.clone(),
            name: tool_name.clone(),
            raw_input: "{}".to_string(),
            input: serde_json::json!({}),
            is_input_complete: true,
        };

        // Synchronously blocks on the tool's output future.
        // NOTE(review): assumes `project_notifications` completes quickly —
        // confirm it never performs long-running work.
        let tool_output = cx.background_executor().block(tool_result.output);

        // Attach a project_notification tool call to the latest existing
        // Assistant message. We cannot create a new Assistant message
        // because thinking models require a `thinking` block that we
        // cannot mock. We cannot send a notification as a normal
        // (non-tool-use) User message because this distracts Agent
        // too much.
        let tool_message_id = self
            .messages
            .iter()
            .enumerate()
            .rfind(|(_, message)| message.role == Role::Assistant)
            .map(|(_, message)| message.id)?;

        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: self.last_prompt_id.clone(),
        };

        self.tool_use
            .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            tool_output,
            self.configured_model.as_ref(),
            self.completion_mode,
        );

        pending_tool_use
    }
1604
    /// Streams a completion for `request` from `model`, updating the thread
    /// incrementally as events arrive.
    ///
    /// Spawns a task that: consumes the event stream (text/thinking chunks,
    /// tool uses, usage updates, status updates), appends or extends
    /// Assistant messages accordingly, then on completion either dispatches
    /// pending tools, reports errors, or schedules a retry. The task is
    /// retained in `self.pending_completions` so it can be cancelled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is registered, capture the request and all
        // raw response events for it.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let completion_mode = request
            .mode
            .unwrap_or(zed_llm_client::CompletionMode::Normal);

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Id of the Assistant message this response is streaming
                // into, once one exists.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        match event? {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates are cumulative per request,
                                // so only add the delta since the last update.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking { data } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses must hang off an Assistant
                                // message; create an empty one if needed.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially-streamed inputs are re-emitted
                                // so the UI can render them incrementally.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            } => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued { position } => {
                                            completion.queue_state =
                                                QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code,
                                            message,
                                            request_id: _,
                                            retry_after,
                                        } => {
                                            // Convert cloud-side failures into
                                            // completion errors so the retry
                                            // machinery below can handle them.
                                            return Err(
                                                LanguageModelCompletionError::from_cloud_failure(
                                                    model.upstream_provider_name(),
                                                    code,
                                                    message,
                                                    retry_after.map(Duration::from_secs_f64),
                                                ),
                                            );
                                        }
                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
                                            thread.update_model_request_usage(
                                                amount as u32,
                                                limit,
                                                cx,
                                            );
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so the UI stays responsive during
                    // fast streams.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;
            let mut retry_scheduled = false;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => {
                            match stop_reason {
                                StopReason::ToolUse => {
                                    let tool_uses =
                                        thread.use_pending_tools(window, model.clone(), cx);
                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                                }
                                StopReason::EndTurn | StopReason::MaxTokens => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });
                                }
                                StopReason::Refusal => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });

                                    // Remove the turn that was refused.
                                    //
                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                    {
                                        let mut messages_to_remove = Vec::new();

                                        // Walk backwards collecting messages
                                        // until the start of the refused turn
                                        // (the User message following the
                                        // previous Assistant message).
                                        for (ix, message) in
                                            thread.messages.iter().enumerate().rev()
                                        {
                                            messages_to_remove.push(message.id);

                                            if message.role == Role::User {
                                                if ix == 0 {
                                                    break;
                                                }

                                                if let Some(prev_message) =
                                                    thread.messages.get(ix - 1)
                                                {
                                                    if prev_message.role == Role::Assistant {
                                                        break;
                                                    }
                                                }
                                            }
                                        }

                                        for message_id in messages_to_remove {
                                            thread.delete_message(message_id, cx);
                                        }
                                    }

                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                        header: "Language model refusal".into(),
                                        message:
                                            "Model refused to generate content for safety reasons."
                                                .into(),
                                    }));
                                }
                            }

                            // We successfully completed, so cancel any remaining retries.
                            thread.retry_state = None;
                        }
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(completion_error) =
                                error.downcast_ref::<LanguageModelCompletionError>()
                            {
                                match &completion_error {
                                    LanguageModelCompletionError::PromptTooLarge {
                                        tokens, ..
                                    } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread
                                                .total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(
                                                    model
                                                        .max_token_count_for_mode(completion_mode)
                                                        .saturating_add(1),
                                                )
                                        });
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: tokens,
                                        });
                                        cx.notify();
                                    }
                                    _ => {
                                        if let Some(retry_strategy) =
                                            Thread::get_retry_strategy(completion_error)
                                        {
                                            retry_scheduled = thread
                                                .handle_retryable_error_with_delay(
                                                    &completion_error,
                                                    Some(retry_strategy),
                                                    model.clone(),
                                                    intent,
                                                    window,
                                                    cx,
                                                );
                                        }
                                    }
                                }
                            }

                            if !retry_scheduled {
                                thread.cancel_last_completion(window, cx);
                            }
                        }
                    }

                    // Suppress the Stopped event while a retry is pending so
                    // listeners don't treat the failure as final.
                    if !retry_scheduled {
                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
                    }

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
2045
2046 pub fn summarize(&mut self, cx: &mut Context<Self>) {
2047 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2048 println!("No thread summary model");
2049 return;
2050 };
2051
2052 if !model.provider.is_authenticated(cx) {
2053 return;
2054 }
2055
2056 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2057
2058 let request = self.to_summarize_request(
2059 &model.model,
2060 CompletionIntent::ThreadSummarization,
2061 added_user_message.into(),
2062 cx,
2063 );
2064
2065 self.summary = ThreadSummary::Generating;
2066
2067 self.pending_summary = cx.spawn(async move |this, cx| {
2068 let result = async {
2069 let mut messages = model.model.stream_completion(request, &cx).await?;
2070
2071 let mut new_summary = String::new();
2072 while let Some(event) = messages.next().await {
2073 let Ok(event) = event else {
2074 continue;
2075 };
2076 let text = match event {
2077 LanguageModelCompletionEvent::Text(text) => text,
2078 LanguageModelCompletionEvent::StatusUpdate(
2079 CompletionRequestStatus::UsageUpdated { amount, limit },
2080 ) => {
2081 this.update(cx, |thread, cx| {
2082 thread.update_model_request_usage(amount as u32, limit, cx);
2083 })?;
2084 continue;
2085 }
2086 _ => continue,
2087 };
2088
2089 let mut lines = text.lines();
2090 new_summary.extend(lines.next());
2091
2092 // Stop if the LLM generated multiple lines.
2093 if lines.next().is_some() {
2094 break;
2095 }
2096 }
2097
2098 anyhow::Ok(new_summary)
2099 }
2100 .await;
2101
2102 this.update(cx, |this, cx| {
2103 match result {
2104 Ok(new_summary) => {
2105 if new_summary.is_empty() {
2106 this.summary = ThreadSummary::Error;
2107 } else {
2108 this.summary = ThreadSummary::Ready(new_summary.into());
2109 }
2110 }
2111 Err(err) => {
2112 this.summary = ThreadSummary::Error;
2113 log::error!("Failed to generate thread summary: {}", err);
2114 }
2115 }
2116 cx.emit(ThreadEvent::SummaryGenerated);
2117 })
2118 .log_err()?;
2119
2120 Some(())
2121 });
2122 }
2123
    /// Maps a completion error to the retry strategy (if any) that should be
    /// used to recover from it. Returns `None` for errors where retrying
    /// cannot help.
    ///
    /// NOTE: match arm order matters here — the specific `HttpResponseError`
    /// status-code arms must stay above the guarded catch-all arm near the
    /// bottom, which would otherwise shadow them.
    fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
        use LanguageModelCompletionError::*;

        // General strategy here:
        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), try multiple times with exponential backoff.
        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), just retry once.
        match error {
            HttpResponseError {
                status_code: StatusCode::TOO_MANY_REQUESTS,
                ..
            } => Some(RetryStrategy::ExponentialBackoff {
                initial_delay: BASE_RETRY_DELAY,
                max_attempts: MAX_RETRY_ATTEMPTS,
            }),
            // Honor the server-provided `retry_after` when present.
            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
                Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    max_attempts: MAX_RETRY_ATTEMPTS,
                })
            }
            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 1,
            }),
            // Transport-level or parse-level flakiness: retry once.
            ApiReadResponseError { .. }
            | HttpSend { .. }
            | DeserializeResponse { .. }
            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 1,
            }),
            // Retrying these errors definitely shouldn't help.
            HttpResponseError {
                status_code:
                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
                ..
            }
            | SerializeRequest { .. }
            | BuildRequestBody { .. }
            | PromptTooLarge { .. }
            | AuthenticationError { .. }
            | PermissionError { .. }
            | ApiEndpointNotFound { .. }
            | NoApiKey { .. } => None,
            // Retry all other 4xx and 5xx errors once.
            HttpResponseError { status_code, .. }
                if status_code.is_client_error() || status_code.is_server_error() =>
            {
                Some(RetryStrategy::Fixed {
                    delay: BASE_RETRY_DELAY,
                    max_attempts: 1,
                })
            }
            // Conservatively assume that any other errors are non-retryable
            HttpResponseError { .. } | Other(..) => None,
        }
    }
2182
    /// Schedules a retry for a failed completion according to `strategy`
    /// (falling back to [`Self::get_retry_strategy`] when `None`).
    ///
    /// Returns `true` when a retry was scheduled. On scheduling, a UI-only
    /// system message describing the retry is added and a timer task is
    /// spawned that re-invokes `send_to_model`. When the attempt budget is
    /// exhausted, retry state and pending completions are cleared and `false`
    /// is returned.
    fn handle_retryable_error_with_delay(
        &mut self,
        error: &LanguageModelCompletionError,
        strategy: Option<RetryStrategy>,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
            return false;
        };

        let max_attempts = match &strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
        };

        // Existing retry state (from a previous failure in the same run) is
        // reused so attempts accumulate across consecutive failures.
        let retry_state = self.retry_state.get_or_insert(RetryState {
            attempt: 0,
            max_attempts,
            intent,
        });

        retry_state.attempt += 1;
        let attempt = retry_state.attempt;
        let max_attempts = retry_state.max_attempts;
        let intent = retry_state.intent;

        if attempt <= max_attempts {
            let delay = match &strategy {
                RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
                    // Doubles the delay on each attempt: d, 2d, 4d, ...
                    let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
                    Duration::from_secs(delay_secs)
                }
                RetryStrategy::Fixed { delay, .. } => *delay,
            };

            // Add a transient message to inform the user
            let delay_secs = delay.as_secs();
            let retry_message = if max_attempts == 1 {
                format!("{error}. Retrying in {delay_secs} seconds...")
            } else {
                format!(
                    "{error}. Retrying (attempt {attempt} of {max_attempts}) \
                    in {delay_secs} seconds..."
                )
            };
            log::warn!(
                "Retrying completion request (attempt {attempt} of {max_attempts}) \
                in {delay_secs} seconds: {error:?}",
            );

            // Add a UI-only message instead of a regular message
            let id = self.next_message_id.post_inc();
            self.messages.push(Message {
                id,
                role: Role::System,
                segments: vec![MessageSegment::Text(retry_message)],
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
                is_hidden: false,
                ui_only: true,
            });
            cx.emit(ThreadEvent::MessageAdded(id));

            // Schedule the retry
            let thread_handle = cx.entity().downgrade();

            cx.spawn(async move |_thread, cx| {
                cx.background_executor().timer(delay).await;

                thread_handle
                    .update(cx, |thread, cx| {
                        // Retry the completion
                        thread.send_to_model(model, intent, window, cx);
                    })
                    .log_err();
            })
            .detach();

            true
        } else {
            // Max retries exceeded
            self.retry_state = None;

            // Stop generating since we're giving up on retrying.
            self.pending_completions.clear();

            false
        }
    }
2275
    /// Starts generating a detailed summary of the thread in the background,
    /// unless a summary is already generated (or generating) for the current
    /// last message.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        // Summarization requires a configured, authenticated summary model.
        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, reset the state so a later call
            // can try again.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate streamed chunks; an individual chunk error is logged
            // and skipped rather than aborting the whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2365
    /// Waits until the detailed summary is no longer generating, then returns
    /// it — or, when no summary was generated, the plain thread text.
    ///
    /// Returns `None` if the thread entity is dropped or the state channel
    /// closes while waiting.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating: keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
2383
2384 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2385 self.detailed_summary_rx
2386 .borrow()
2387 .text()
2388 .unwrap_or_else(|| self.text().into())
2389 }
2390
    /// Whether a detailed summary is currently being generated for this
    /// thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2397
2398 pub fn use_pending_tools(
2399 &mut self,
2400 window: Option<AnyWindowHandle>,
2401 model: Arc<dyn LanguageModel>,
2402 cx: &mut Context<Self>,
2403 ) -> Vec<PendingToolUse> {
2404 self.auto_capture_telemetry(cx);
2405 let request =
2406 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2407 let pending_tool_uses = self
2408 .tool_use
2409 .pending_tool_uses()
2410 .into_iter()
2411 .filter(|tool_use| tool_use.status.is_idle())
2412 .cloned()
2413 .collect::<Vec<_>>();
2414
2415 for tool_use in pending_tool_uses.iter() {
2416 self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2417 }
2418
2419 pending_tool_uses
2420 }
2421
    /// Dispatches a single pending tool use: unknown or disabled tools are
    /// reported back to the model as hallucinated, tools that need user
    /// confirmation are parked until confirmed, and everything else runs
    /// immediately.
    fn use_pending_tool(
        &mut self,
        tool_use: PendingToolUse,
        request: Arc<LanguageModelRequest>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        // The model asked for a tool we don't have at all.
        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
        };

        // The tool exists but is disabled in the active profile.
        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
        }

        if tool.needs_confirmation(&tool_use.input, cx)
            && !AgentSettings::get_global(cx).always_allow_tool_actions
        {
            // Park the tool use until the user explicitly confirms it.
            self.tool_use.confirm_tool_use(
                tool_use.id,
                tool_use.ui_text,
                tool_use.input,
                request,
                tool,
            );
            cx.emit(ThreadEvent::ToolConfirmationNeeded);
        } else {
            self.run_tool(
                tool_use.id,
                tool_use.ui_text,
                tool_use.input,
                request,
                tool,
                model,
                window,
                cx,
            );
        }
    }
2462
2463 pub fn handle_hallucinated_tool_use(
2464 &mut self,
2465 tool_use_id: LanguageModelToolUseId,
2466 hallucinated_tool_name: Arc<str>,
2467 window: Option<AnyWindowHandle>,
2468 cx: &mut Context<Thread>,
2469 ) {
2470 let available_tools = self.profile.enabled_tools(cx);
2471
2472 let tool_list = available_tools
2473 .iter()
2474 .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2475 .collect::<Vec<_>>()
2476 .join("\n");
2477
2478 let error_message = format!(
2479 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2480 hallucinated_tool_name, tool_list
2481 );
2482
2483 let pending_tool_use = self.tool_use.insert_tool_output(
2484 tool_use_id.clone(),
2485 hallucinated_tool_name,
2486 Err(anyhow!("Missing tool call: {error_message}")),
2487 self.configured_model.as_ref(),
2488 self.completion_mode,
2489 );
2490
2491 cx.emit(ThreadEvent::MissingToolUse {
2492 tool_use_id: tool_use_id.clone(),
2493 ui_text: error_message.into(),
2494 });
2495
2496 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2497 }
2498
    /// Records a tool call whose input JSON failed to parse: the parse error
    /// is handed back to the model as the tool's (error) result, and
    /// listeners are notified so the invalid input can be shown in the UI.
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
            self.completion_mode,
        );
        // Prefer the tool's own UI text; fall back to a generic label if the
        // tool use was unexpectedly not tracked as pending.
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
2534
2535 pub fn run_tool(
2536 &mut self,
2537 tool_use_id: LanguageModelToolUseId,
2538 ui_text: impl Into<SharedString>,
2539 input: serde_json::Value,
2540 request: Arc<LanguageModelRequest>,
2541 tool: Arc<dyn Tool>,
2542 model: Arc<dyn LanguageModel>,
2543 window: Option<AnyWindowHandle>,
2544 cx: &mut Context<Thread>,
2545 ) {
2546 let task =
2547 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2548 self.tool_use
2549 .run_pending_tool(tool_use_id, ui_text.into(), task);
2550 }
2551
    /// Spawns the actual execution of a tool, returning the task that
    /// resolves once the tool's output has been recorded on the thread.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Kick off the tool synchronously; it yields a future for its output.
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                // Await the tool's output, then record it; if the thread
                // entity was dropped meanwhile the result is discarded.
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                            thread.completion_mode,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2599
2600 fn tool_finished(
2601 &mut self,
2602 tool_use_id: LanguageModelToolUseId,
2603 pending_tool_use: Option<PendingToolUse>,
2604 canceled: bool,
2605 window: Option<AnyWindowHandle>,
2606 cx: &mut Context<Self>,
2607 ) {
2608 if self.all_tools_finished() {
2609 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2610 if !canceled {
2611 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2612 }
2613 self.auto_capture_telemetry(cx);
2614 }
2615 }
2616
2617 cx.emit(ThreadEvent::ToolFinished {
2618 tool_use_id,
2619 pending_tool_use,
2620 });
2621 }
2622
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // A scheduled retry counts as an in-flight completion for
        // cancellation purposes.
        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();

        self.retry_state = None;

        // Cancel any still-running tools, notifying listeners that each one
        // finished (as canceled).
        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }
2661
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It performs no state
    /// changes itself; listeners react to the emitted event.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2669
    /// The thread-level feedback rating, if the user has provided one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2673
    /// The feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2677
    /// Records the user's rating for a specific message and reports it (along
    /// with a serialized copy of the thread and a project snapshot) to
    /// telemetry.
    ///
    /// No-op when the same feedback is already recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything needed before moving to the background executor.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|(name, _)| name.clone().into())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2734
    /// Records thread-level feedback. When the thread contains an assistant
    /// message, the rating is attached to the most recent one via
    /// `report_message_feedback`; otherwise a thread-level telemetry event
    /// (without message details) is sent.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach the rating to: report at the
            // thread level only.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2780
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; the list of dirty
    /// (unsaved) buffer paths is collected afterwards on the main thread.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record which open buffers have unsaved changes at snapshot time.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2818
2819 fn worktree_snapshot(
2820 worktree: Entity<project::Worktree>,
2821 git_store: Entity<GitStore>,
2822 cx: &App,
2823 ) -> Task<WorktreeSnapshot> {
2824 cx.spawn(async move |cx| {
2825 // Get worktree path and snapshot
2826 let worktree_info = cx.update(|app_cx| {
2827 let worktree = worktree.read(app_cx);
2828 let path = worktree.abs_path().to_string_lossy().to_string();
2829 let snapshot = worktree.snapshot();
2830 (path, snapshot)
2831 });
2832
2833 let Ok((worktree_path, _snapshot)) = worktree_info else {
2834 return WorktreeSnapshot {
2835 worktree_path: String::new(),
2836 git_state: None,
2837 };
2838 };
2839
2840 let git_state = git_store
2841 .update(cx, |git_store, cx| {
2842 git_store
2843 .repositories()
2844 .values()
2845 .find(|repo| {
2846 repo.read(cx)
2847 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2848 .is_some()
2849 })
2850 .cloned()
2851 })
2852 .ok()
2853 .flatten()
2854 .map(|repo| {
2855 repo.update(cx, |repo, _| {
2856 let current_branch =
2857 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2858 repo.send_job(None, |state, _| async move {
2859 let RepositoryState::Local { backend, .. } = state else {
2860 return GitState {
2861 remote_url: None,
2862 head_sha: None,
2863 current_branch,
2864 diff: None,
2865 };
2866 };
2867
2868 let remote_url = backend.remote_url("origin");
2869 let head_sha = backend.head_sha().await;
2870 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2871
2872 GitState {
2873 remote_url,
2874 head_sha,
2875 current_branch,
2876 diff,
2877 }
2878 })
2879 })
2880 });
2881
2882 let git_state = match git_state {
2883 Some(git_state) => match git_state.ok() {
2884 Some(git_state) => git_state.await.ok(),
2885 None => None,
2886 },
2887 None => None,
2888 };
2889
2890 WorktreeSnapshot {
2891 worktree_path,
2892 git_state,
2893 }
2894 })
2895 }
2896
2897 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2898 let mut markdown = Vec::new();
2899
2900 let summary = self.summary().or_default();
2901 writeln!(markdown, "# {summary}\n")?;
2902
2903 for message in self.messages() {
2904 writeln!(
2905 markdown,
2906 "## {role}\n",
2907 role = match message.role {
2908 Role::User => "User",
2909 Role::Assistant => "Agent",
2910 Role::System => "System",
2911 }
2912 )?;
2913
2914 if !message.loaded_context.text.is_empty() {
2915 writeln!(markdown, "{}", message.loaded_context.text)?;
2916 }
2917
2918 if !message.loaded_context.images.is_empty() {
2919 writeln!(
2920 markdown,
2921 "\n{} images attached as context.\n",
2922 message.loaded_context.images.len()
2923 )?;
2924 }
2925
2926 for segment in &message.segments {
2927 match segment {
2928 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2929 MessageSegment::Thinking { text, .. } => {
2930 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2931 }
2932 MessageSegment::RedactedThinking(_) => {}
2933 }
2934 }
2935
2936 for tool_use in self.tool_uses_for_message(message.id, cx) {
2937 writeln!(
2938 markdown,
2939 "**Use Tool: {} ({})**",
2940 tool_use.name, tool_use.id
2941 )?;
2942 writeln!(markdown, "```json")?;
2943 writeln!(
2944 markdown,
2945 "{}",
2946 serde_json::to_string_pretty(&tool_use.input)?
2947 )?;
2948 writeln!(markdown, "```")?;
2949 }
2950
2951 for tool_result in self.tool_results_for_message(message.id) {
2952 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2953 if tool_result.is_error {
2954 write!(markdown, " (Error)")?;
2955 }
2956
2957 writeln!(markdown, "**\n")?;
2958 match &tool_result.content {
2959 LanguageModelToolResultContent::Text(text) => {
2960 writeln!(markdown, "{text}")?;
2961 }
2962 LanguageModelToolResultContent::Image(image) => {
2963 writeln!(markdown, "", image.source)?;
2964 }
2965 }
2966
2967 if let Some(output) = tool_result.output.as_ref() {
2968 writeln!(
2969 markdown,
2970 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2971 serde_json::to_string_pretty(output)?
2972 )?;
2973 }
2974 }
2975 }
2976
2977 Ok(String::from_utf8_lossy(&markdown).to_string())
2978 }
2979
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept
    /// (accepted by the user).
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2990
    /// Marks every agent edit in this thread as kept (accepted by the user).
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2995
    /// Reverts the agent's edits within the given ranges of `buffer`,
    /// returning a task that resolves once the rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
3006
    /// The log of buffer edits the agent has performed in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
3010
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
3014
    /// Opportunistically captures the serialized thread to telemetry for
    /// users in the auto-capture feature flag, rate-limited to at most one
    /// capture every 10 seconds.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
            return;
        }

        // Rate-limit: skip if we captured within the last 10 seconds.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client();
        let serialize_task = self.serialize(cx);

        // Serialization and upload happen off the main thread; failures are
        // silently dropped since this is best-effort telemetry.
        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }
3058
    /// Token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3062
3063 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3064 let Some(model) = self.configured_model.as_ref() else {
3065 return TotalTokenUsage::default();
3066 };
3067
3068 let max = model
3069 .model
3070 .max_token_count_for_mode(self.completion_mode().into());
3071
3072 let index = self
3073 .messages
3074 .iter()
3075 .position(|msg| msg.id == message_id)
3076 .unwrap_or(0);
3077
3078 if index == 0 {
3079 return TotalTokenUsage { total: 0, max };
3080 }
3081
3082 let token_usage = &self
3083 .request_token_usage
3084 .get(index - 1)
3085 .cloned()
3086 .unwrap_or_default();
3087
3088 TotalTokenUsage {
3089 total: token_usage.total_tokens(),
3090 max,
3091 }
3092 }
3093
    /// Returns the thread's current total token usage and the model's context
    /// limit, or `None` when no model is configured.
    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
        let model = self.configured_model.as_ref()?;

        let max = model
            .model
            .max_token_count_for_mode(self.completion_mode().into());

        // If this same model previously exceeded its context window, report
        // the token count recorded at that point instead.
        if let Some(exceeded_error) = &self.exceeded_window_error {
            if model.model.id() == exceeded_error.model_id {
                return Some(TotalTokenUsage {
                    total: exceeded_error.token_count,
                    max,
                });
            }
        }

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens();

        Some(TotalTokenUsage { total, max })
    }
3117
    /// The request token usage recorded for the last message.
    ///
    /// NOTE(review): when no entry exists at `messages.len() - 1`, this falls
    /// back to the most recent recorded usage — presumably because usage
    /// entries can lag behind messages; confirm against how
    /// `request_token_usage` is populated.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
3124
    /// Overwrites the token usage recorded for the last message, growing the
    /// usage list if it is shorter than the message list.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Pad any missing entries with the last known usage so intermediate
        // entries stay plausible.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
3134
    /// Pushes the latest model-request usage (`amount` used against `limit`)
    /// into the user store so limit/billing UI stays current.
    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            project.user_store().update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        // NOTE(review): `u32 as i32` wraps above i32::MAX —
                        // presumably amounts stay far below that; confirm.
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            })
        });
    }
3148
3149 pub fn deny_tool_use(
3150 &mut self,
3151 tool_use_id: LanguageModelToolUseId,
3152 tool_name: Arc<str>,
3153 window: Option<AnyWindowHandle>,
3154 cx: &mut Context<Self>,
3155 ) {
3156 let err = Err(anyhow::anyhow!(
3157 "Permission to run tool action denied by user"
3158 ));
3159
3160 self.tool_use.insert_tool_output(
3161 tool_use_id.clone(),
3162 tool_name,
3163 err,
3164 self.configured_model.as_ref(),
3165 self.completion_mode,
3166 );
3167 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3168 }
3169}
3170
/// User-facing errors that can terminate a completion request.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's plan requires payment before further requests are served.
    #[error("Payment required")]
    PaymentRequired,
    /// The user hit their plan's model-request limit.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, presented with a header and a detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3183
/// Events emitted by a thread as completions stream in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of assistant text was streamed for the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text was streamed for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model called a tool that doesn't exist or isn't enabled.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model supplied tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stream ended, successfully or with an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
}
3228
// Allows `Thread` to emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
3230
/// An in-flight completion request tracked by the thread.
struct PendingCompletion {
    /// Identifier used to match this completion when it finishes or is popped.
    id: usize,
    queue_state: QueueState,
    /// Dropping this task cancels the underlying request.
    _task: Task<()>,
}
3236
3237#[cfg(test)]
3238mod tests {
3239 use super::*;
3240 use crate::{
3241 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3242 };
3243
    // Test-specific constants: retry delay, in seconds, used by the
    // retry-related tests below (presumably mirroring a rate-limit header).
    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3246 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3247 use assistant_tool::ToolRegistry;
3248 use assistant_tools;
3249 use futures::StreamExt;
3250 use futures::future::BoxFuture;
3251 use futures::stream::BoxStream;
3252 use gpui::TestAppContext;
3253 use http_client;
3254 use indoc::indoc;
3255 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3256 use language_model::{
3257 LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3258 LanguageModelProviderName, LanguageModelToolChoice,
3259 };
3260 use parking_lot::Mutex;
3261 use project::{FakeFs, Project};
3262 use prompt_store::PromptBuilder;
3263 use serde_json::json;
3264 use settings::{Settings, SettingsStore};
3265 use std::sync::Arc;
3266 use std::time::Duration;
3267 use theme::ThemeSettings;
3268 use util::path;
3269 use workspace::Workspace;
3270
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        // Verifies that file context attached to a user message appears both
        // on the stored message and in the completion request sent to the
        // model (prepended to the typed text).
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Two request messages: index 1 is our user message with the context
        // prepended (index 0 is presumably the system prompt — see
        // to_completion_request).
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3347
    /// `ContextStore::new_context_for_thread` must return only context that is
    /// not already attached to an earlier message, so each user message embeds
    /// just the files added since the previous one. Passing `Some(message_id)`
    /// re-admits context attached at or after that message — presumably to
    /// support re-editing that message; TODO confirm against the callers.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 user messages, plus one leading
        // message at index 0 (likely the system prompt — indices below start at 1).
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With `Some(message2_id)`, contexts from message 2 onward (file2, file3)
        // plus the newly added file4 count as "new"; only message 1's file1 is excluded.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs — assumes `contexts` preserves insertion order
            // (file2, file3, file4), which the asserts below bear out.
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3503
3504 #[gpui::test]
3505 async fn test_message_without_files(cx: &mut TestAppContext) {
3506 init_test_settings(cx);
3507
3508 let project = create_test_project(
3509 cx,
3510 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3511 )
3512 .await;
3513
3514 let (_, _thread_store, thread, _context_store, model) =
3515 setup_test_environment(cx, project.clone()).await;
3516
3517 // Insert user message without any context (empty context vector)
3518 let message_id = thread.update(cx, |thread, cx| {
3519 thread.insert_user_message(
3520 "What is the best way to learn Rust?",
3521 ContextLoadResult::default(),
3522 None,
3523 Vec::new(),
3524 cx,
3525 )
3526 });
3527
3528 // Check content and context in message object
3529 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3530
3531 // Context should be empty when no files are included
3532 assert_eq!(message.role, Role::User);
3533 assert_eq!(message.segments.len(), 1);
3534 assert_eq!(
3535 message.segments[0],
3536 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3537 );
3538 assert_eq!(message.loaded_context.text, "");
3539
3540 // Check message in request
3541 let request = thread.update(cx, |thread, cx| {
3542 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3543 });
3544
3545 assert_eq!(request.messages.len(), 2);
3546 assert_eq!(
3547 request.messages[1].string_contents(),
3548 "What is the best way to learn Rust?"
3549 );
3550
3551 // Add second message, also without context
3552 let message2_id = thread.update(cx, |thread, cx| {
3553 thread.insert_user_message(
3554 "Are there any good books?",
3555 ContextLoadResult::default(),
3556 None,
3557 Vec::new(),
3558 cx,
3559 )
3560 });
3561
3562 let message2 =
3563 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3564 assert_eq!(message2.loaded_context.text, "");
3565
3566 // Check that both messages appear in the request
3567 let request = thread.update(cx, |thread, cx| {
3568 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3569 });
3570
3571 assert_eq!(request.messages.len(), 3);
3572 assert_eq!(
3573 request.messages[1].string_contents(),
3574 "What is the best way to learn Rust?"
3575 );
3576 assert_eq!(
3577 request.messages[2].string_contents(),
3578 "Are there any good books?"
3579 );
3580 }
3581
    /// When a buffer that was attached as context is edited afterwards, the
    /// next `flush_notifications` must inject exactly one
    /// `project_notifications` tool-use listing the changed file — and flushing
    /// again without further edits must not add a duplicate.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context. This will be a tracked buffer
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message and assistant response
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });

        // We shouldn't have a stale buffer notification yet
        let notifications = thread.read_with(cx, |thread, _| {
            find_tool_uses(thread, "project_notifications")
        });
        assert!(
            notifications.is_empty(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Modify the buffer, making the previously-attached context stale.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Flushing should now materialize the stale-buffer warning as a tool use.
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        let [notification] = notifications.as_slice() else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]

            These files have changed since the last read:
            - code.rs
            "};
        assert_eq!(notification_content, expected_content);

        // Insert another user message and flush notifications again
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Can you tell me more?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // There should be no new notifications (we already flushed one)
        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        assert_eq!(
            notifications.len(),
            1,
            "Should still have only one notification after second flush - no duplicates"
        );
    }
3695
3696 fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3697 thread
3698 .messages()
3699 .flat_map(|message| {
3700 thread
3701 .tool_results_for_message(message.id)
3702 .into_iter()
3703 .filter(|result| result.tool_name == tool_name.into())
3704 .cloned()
3705 .collect::<Vec<_>>()
3706 })
3707 .collect()
3708 }
3709
3710 #[gpui::test]
3711 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3712 init_test_settings(cx);
3713
3714 let project = create_test_project(
3715 cx,
3716 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3717 )
3718 .await;
3719
3720 let (_workspace, thread_store, thread, _context_store, _model) =
3721 setup_test_environment(cx, project.clone()).await;
3722
3723 // Check that we are starting with the default profile
3724 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3725 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3726 assert_eq!(
3727 profile,
3728 AgentProfile::new(AgentProfileId::default(), tool_set)
3729 );
3730 }
3731
3732 #[gpui::test]
3733 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3734 init_test_settings(cx);
3735
3736 let project = create_test_project(
3737 cx,
3738 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3739 )
3740 .await;
3741
3742 let (_workspace, thread_store, thread, _context_store, _model) =
3743 setup_test_environment(cx, project.clone()).await;
3744
3745 // Profile gets serialized with default values
3746 let serialized = thread
3747 .update(cx, |thread, cx| thread.serialize(cx))
3748 .await
3749 .unwrap();
3750
3751 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3752
3753 let deserialized = cx.update(|cx| {
3754 thread.update(cx, |thread, cx| {
3755 Thread::deserialize(
3756 thread.id.clone(),
3757 serialized,
3758 thread.project.clone(),
3759 thread.tools.clone(),
3760 thread.prompt_builder.clone(),
3761 thread.project_context.clone(),
3762 None,
3763 cx,
3764 )
3765 })
3766 });
3767 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3768
3769 assert_eq!(
3770 deserialized.profile,
3771 AgentProfile::new(AgentProfileId::default(), tool_set)
3772 );
3773 }
3774
3775 #[gpui::test]
3776 async fn test_temperature_setting(cx: &mut TestAppContext) {
3777 init_test_settings(cx);
3778
3779 let project = create_test_project(
3780 cx,
3781 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3782 )
3783 .await;
3784
3785 let (_workspace, _thread_store, thread, _context_store, model) =
3786 setup_test_environment(cx, project.clone()).await;
3787
3788 // Both model and provider
3789 cx.update(|cx| {
3790 AgentSettings::override_global(
3791 AgentSettings {
3792 model_parameters: vec![LanguageModelParameters {
3793 provider: Some(model.provider_id().0.to_string().into()),
3794 model: Some(model.id().0.clone()),
3795 temperature: Some(0.66),
3796 }],
3797 ..AgentSettings::get_global(cx).clone()
3798 },
3799 cx,
3800 );
3801 });
3802
3803 let request = thread.update(cx, |thread, cx| {
3804 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3805 });
3806 assert_eq!(request.temperature, Some(0.66));
3807
3808 // Only model
3809 cx.update(|cx| {
3810 AgentSettings::override_global(
3811 AgentSettings {
3812 model_parameters: vec![LanguageModelParameters {
3813 provider: None,
3814 model: Some(model.id().0.clone()),
3815 temperature: Some(0.66),
3816 }],
3817 ..AgentSettings::get_global(cx).clone()
3818 },
3819 cx,
3820 );
3821 });
3822
3823 let request = thread.update(cx, |thread, cx| {
3824 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3825 });
3826 assert_eq!(request.temperature, Some(0.66));
3827
3828 // Only provider
3829 cx.update(|cx| {
3830 AgentSettings::override_global(
3831 AgentSettings {
3832 model_parameters: vec![LanguageModelParameters {
3833 provider: Some(model.provider_id().0.to_string().into()),
3834 model: None,
3835 temperature: Some(0.66),
3836 }],
3837 ..AgentSettings::get_global(cx).clone()
3838 },
3839 cx,
3840 );
3841 });
3842
3843 let request = thread.update(cx, |thread, cx| {
3844 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3845 });
3846 assert_eq!(request.temperature, Some(0.66));
3847
3848 // Same model name, different provider
3849 cx.update(|cx| {
3850 AgentSettings::override_global(
3851 AgentSettings {
3852 model_parameters: vec![LanguageModelParameters {
3853 provider: Some("anthropic".into()),
3854 model: Some(model.id().0.clone()),
3855 temperature: Some(0.66),
3856 }],
3857 ..AgentSettings::get_global(cx).clone()
3858 },
3859 cx,
3860 );
3861 });
3862
3863 let request = thread.update(cx, |thread, cx| {
3864 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3865 });
3866 assert_eq!(request.temperature, None);
3867 }
3868
    /// Summary lifecycle: `Pending` -> `Generating` (once a first exchange
    /// exists) -> `Ready`. Manual `set_summary` is ignored while Pending or
    /// Generating, works once Ready, and setting an empty string falls back to
    /// `ThreadSummary::DEFAULT`.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks; they should be concatenated.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3953
3954 #[gpui::test]
3955 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3956 init_test_settings(cx);
3957
3958 let project = create_test_project(cx, json!({})).await;
3959
3960 let (_, _thread_store, thread, _context_store, model) =
3961 setup_test_environment(cx, project.clone()).await;
3962
3963 test_summarize_error(&model, &thread, cx);
3964
3965 // Now we should be able to set a summary
3966 thread.update(cx, |thread, cx| {
3967 thread.set_summary("Brief Intro", cx);
3968 });
3969
3970 thread.read_with(cx, |thread, _| {
3971 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3972 assert_eq!(thread.summary().or_default(), "Brief Intro");
3973 });
3974 }
3975
    /// After a summarization error, sending further messages must NOT silently
    /// re-trigger summarization; only an explicit `summarize()` call does, and
    /// on success the state transitions Error -> Generating -> Ready.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the summarization request into the error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the manual retry.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
4026
    // Test-only error selector for `ErrorInjector`, a model wrapper whose
    // completions always fail (see `impl LanguageModel for ErrorInjector`).
    enum TestError {
        /// Surfaces `LanguageModelCompletionError::ServerOverloaded` (no retry-after).
        Overloaded,
        /// Surfaces `LanguageModelCompletionError::ApiInternalServerError`.
        InternalServerError,
    }
4032
    /// A `LanguageModel` that delegates all metadata to an inner
    /// `FakeLanguageModel` but fails every completion stream with the
    /// configured `TestError`. Used to exercise the thread's retry logic.
    struct ErrorInjector {
        // Supplies ids/names/capabilities and the `as_fake` test hook.
        inner: Arc<FakeLanguageModel>,
        // The error every `stream_completion` call will yield.
        error_type: TestError,
    }
4037
4038 impl ErrorInjector {
4039 fn new(error_type: TestError) -> Self {
4040 Self {
4041 inner: Arc::new(FakeLanguageModel::default()),
4042 error_type,
4043 }
4044 }
4045 }
4046
    impl LanguageModel for ErrorInjector {
        // All metadata, capability, and token-accounting queries are forwarded
        // to the inner fake model; only `stream_completion` is overridden.
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        /// Ignores the request and yields a stream whose only item is the
        /// configured error, so callers see a failed completion attempt.
        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Build the error eagerly (it captures the provider name) so the
            // returned future is 'static.
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
                    provider: self.provider_name(),
                    retry_after: None,
                },
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError {
                        provider: self.provider_name(),
                        message: "I'm a teapot orbiting the sun".to_string(),
                    }
                }
            };
            async move {
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        // Exposes the inner fake so tests can drive/inspect it directly.
        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
4129
    /// An overloaded-provider error must schedule a retry: `retry_state` is set
    /// to attempt 1 of `MAX_RETRY_ATTEMPTS`, and exactly one ui-only System
    /// message announcing the retry is appended to the thread.
    #[gpui::test]
    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
            );
        });

        // Check that a retry message was added
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("overloaded")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message"
            );
        });

        // Count the ui-only "Retrying ... seconds" notices; exactly one retry
        // should have been scheduled so far.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4202
4203 #[gpui::test]
4204 async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4205 init_test_settings(cx);
4206
4207 let project = create_test_project(cx, json!({})).await;
4208 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4209
4210 // Create model that returns internal server error
4211 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4212
4213 // Insert a user message
4214 thread.update(cx, |thread, cx| {
4215 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4216 });
4217
4218 // Start completion
4219 thread.update(cx, |thread, cx| {
4220 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4221 });
4222
4223 cx.run_until_parked();
4224
4225 // Check retry state on thread
4226 thread.read_with(cx, |thread, _| {
4227 assert!(thread.retry_state.is_some(), "Should have retry state");
4228 let retry_state = thread.retry_state.as_ref().unwrap();
4229 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4230 assert_eq!(
4231 retry_state.max_attempts, 1,
4232 "Should have correct max attempts"
4233 );
4234 });
4235
4236 // Check that a retry message was added with provider name
4237 thread.read_with(cx, |thread, _| {
4238 let mut messages = thread.messages();
4239 assert!(
4240 messages.any(|msg| {
4241 msg.role == Role::System
4242 && msg.ui_only
4243 && msg.segments.iter().any(|seg| {
4244 if let MessageSegment::Text(text) = seg {
4245 text.contains("internal")
4246 && text.contains("Fake")
4247 && text.contains("Retrying in")
4248 && !text.contains("attempt")
4249 } else {
4250 false
4251 }
4252 })
4253 }),
4254 "Should have added a system retry message with provider name"
4255 );
4256 });
4257
4258 // Count retry messages
4259 let retry_count = thread.update(cx, |thread, _| {
4260 thread
4261 .messages
4262 .iter()
4263 .filter(|m| {
4264 m.ui_only
4265 && m.segments.iter().any(|s| {
4266 if let MessageSegment::Text(text) = s {
4267 text.contains("Retrying") && text.contains("seconds")
4268 } else {
4269 false
4270 }
4271 })
4272 })
4273 .count()
4274 });
4275
4276 assert_eq!(retry_count, 1, "Should have one retry message");
4277 }
4278
4279 #[gpui::test]
4280 async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4281 init_test_settings(cx);
4282
4283 let project = create_test_project(cx, json!({})).await;
4284 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4285
4286 // Create model that returns internal server error
4287 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4288
4289 // Insert a user message
4290 thread.update(cx, |thread, cx| {
4291 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4292 });
4293
4294 // Track retry events and completion count
4295 // Track completion events
4296 let completion_count = Arc::new(Mutex::new(0));
4297 let completion_count_clone = completion_count.clone();
4298
4299 let _subscription = thread.update(cx, |_, cx| {
4300 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4301 if let ThreadEvent::NewRequest = event {
4302 *completion_count_clone.lock() += 1;
4303 }
4304 })
4305 });
4306
4307 // First attempt
4308 thread.update(cx, |thread, cx| {
4309 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4310 });
4311 cx.run_until_parked();
4312
4313 // Should have scheduled first retry - count retry messages
4314 let retry_count = thread.update(cx, |thread, _| {
4315 thread
4316 .messages
4317 .iter()
4318 .filter(|m| {
4319 m.ui_only
4320 && m.segments.iter().any(|s| {
4321 if let MessageSegment::Text(text) = s {
4322 text.contains("Retrying") && text.contains("seconds")
4323 } else {
4324 false
4325 }
4326 })
4327 })
4328 .count()
4329 });
4330 assert_eq!(retry_count, 1, "Should have scheduled first retry");
4331
4332 // Check retry state
4333 thread.read_with(cx, |thread, _| {
4334 assert!(thread.retry_state.is_some(), "Should have retry state");
4335 let retry_state = thread.retry_state.as_ref().unwrap();
4336 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4337 assert_eq!(
4338 retry_state.max_attempts, 1,
4339 "Internal server errors should only retry once"
4340 );
4341 });
4342
4343 // Advance clock for first retry
4344 cx.executor().advance_clock(BASE_RETRY_DELAY);
4345 cx.run_until_parked();
4346
4347 // Should have scheduled second retry - count retry messages
4348 let retry_count = thread.update(cx, |thread, _| {
4349 thread
4350 .messages
4351 .iter()
4352 .filter(|m| {
4353 m.ui_only
4354 && m.segments.iter().any(|s| {
4355 if let MessageSegment::Text(text) = s {
4356 text.contains("Retrying") && text.contains("seconds")
4357 } else {
4358 false
4359 }
4360 })
4361 })
4362 .count()
4363 });
4364 assert_eq!(
4365 retry_count, 1,
4366 "Should have only one retry for internal server errors"
4367 );
4368
4369 // For internal server errors, we only retry once and then give up
4370 // Check that retry_state is cleared after the single retry
4371 thread.read_with(cx, |thread, _| {
4372 assert!(
4373 thread.retry_state.is_none(),
4374 "Retry state should be cleared after single retry"
4375 );
4376 });
4377
4378 // Verify total attempts (1 initial + 1 retry)
4379 assert_eq!(
4380 *completion_count.lock(),
4381 2,
4382 "Should have attempted once plus 1 retry"
4383 );
4384 }
4385
    /// When every attempt fails with an overloaded error, the thread retries
    /// `MAX_RETRY_ATTEMPTS` times (one ui-only System notice each), then emits
    /// `ThreadEvent::Stopped(Err(..))` and clears `retry_state`.
    #[gpui::test]
    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Record whether the thread eventually stopped with an error.
        let stopped_with_error = Arc::new(Mutex::new(false));
        let stopped_with_error_clone = stopped_with_error.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::Stopped(Err(_)) = event {
                    *stopped_with_error_clone.lock() = true;
                }
            })
        });

        // Start initial completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Advance the fake clock past each retry delay so every scheduled
        // retry fires (and fails) in turn.
        for _ in 0..MAX_RETRY_ATTEMPTS {
            cx.executor().advance_clock(BASE_RETRY_DELAY);
            cx.run_until_parked();
        }

        // Count the ui-only "Retrying ... seconds" notices accumulated above.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        // After max retries, should emit Stopped(Err(...)) event
        assert_eq!(
            retry_count, MAX_RETRY_ATTEMPTS as usize,
            "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
        );
        assert!(
            *stopped_with_error.lock(),
            "Should emit Stopped(Err(...)) event after max retries exceeded"
        );

        // Retry state should be cleared
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after max retries"
            );

            // Verify we have the expected number of retry messages
            let retry_messages = thread
                .messages
                .iter()
                .filter(|msg| msg.ui_only && msg.role == Role::System)
                .count();
            assert_eq!(
                retry_messages, MAX_RETRY_ATTEMPTS as usize,
                "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
            );
        });
    }
4471
4472 #[gpui::test]
4473 async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
4474 init_test_settings(cx);
4475
4476 let project = create_test_project(cx, json!({})).await;
4477 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4478
4479 // We'll use a wrapper to switch behavior after first failure
4480 struct RetryTestModel {
4481 inner: Arc<FakeLanguageModel>,
4482 failed_once: Arc<Mutex<bool>>,
4483 }
4484
4485 impl LanguageModel for RetryTestModel {
4486 fn id(&self) -> LanguageModelId {
4487 self.inner.id()
4488 }
4489
4490 fn name(&self) -> LanguageModelName {
4491 self.inner.name()
4492 }
4493
4494 fn provider_id(&self) -> LanguageModelProviderId {
4495 self.inner.provider_id()
4496 }
4497
4498 fn provider_name(&self) -> LanguageModelProviderName {
4499 self.inner.provider_name()
4500 }
4501
4502 fn supports_tools(&self) -> bool {
4503 self.inner.supports_tools()
4504 }
4505
4506 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4507 self.inner.supports_tool_choice(choice)
4508 }
4509
4510 fn supports_images(&self) -> bool {
4511 self.inner.supports_images()
4512 }
4513
4514 fn telemetry_id(&self) -> String {
4515 self.inner.telemetry_id()
4516 }
4517
4518 fn max_token_count(&self) -> u64 {
4519 self.inner.max_token_count()
4520 }
4521
4522 fn count_tokens(
4523 &self,
4524 request: LanguageModelRequest,
4525 cx: &App,
4526 ) -> BoxFuture<'static, Result<u64>> {
4527 self.inner.count_tokens(request, cx)
4528 }
4529
4530 fn stream_completion(
4531 &self,
4532 request: LanguageModelRequest,
4533 cx: &AsyncApp,
4534 ) -> BoxFuture<
4535 'static,
4536 Result<
4537 BoxStream<
4538 'static,
4539 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4540 >,
4541 LanguageModelCompletionError,
4542 >,
4543 > {
4544 if !*self.failed_once.lock() {
4545 *self.failed_once.lock() = true;
4546 let provider = self.provider_name();
4547 // Return error on first attempt
4548 let stream = futures::stream::once(async move {
4549 Err(LanguageModelCompletionError::ServerOverloaded {
4550 provider,
4551 retry_after: None,
4552 })
4553 });
4554 async move { Ok(stream.boxed()) }.boxed()
4555 } else {
4556 // Succeed on retry
4557 self.inner.stream_completion(request, cx)
4558 }
4559 }
4560
4561 fn as_fake(&self) -> &FakeLanguageModel {
4562 &self.inner
4563 }
4564 }
4565
4566 let model = Arc::new(RetryTestModel {
4567 inner: Arc::new(FakeLanguageModel::default()),
4568 failed_once: Arc::new(Mutex::new(false)),
4569 });
4570
4571 // Insert a user message
4572 thread.update(cx, |thread, cx| {
4573 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4574 });
4575
4576 // Track message deletions
4577 // Track when retry completes successfully
4578 let retry_completed = Arc::new(Mutex::new(false));
4579 let retry_completed_clone = retry_completed.clone();
4580
4581 let _subscription = thread.update(cx, |_, cx| {
4582 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4583 if let ThreadEvent::StreamedCompletion = event {
4584 *retry_completed_clone.lock() = true;
4585 }
4586 })
4587 });
4588
4589 // Start completion
4590 thread.update(cx, |thread, cx| {
4591 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4592 });
4593 cx.run_until_parked();
4594
4595 // Get the retry message ID
4596 let retry_message_id = thread.read_with(cx, |thread, _| {
4597 thread
4598 .messages()
4599 .find(|msg| msg.role == Role::System && msg.ui_only)
4600 .map(|msg| msg.id)
4601 .expect("Should have a retry message")
4602 });
4603
4604 // Wait for retry
4605 cx.executor().advance_clock(BASE_RETRY_DELAY);
4606 cx.run_until_parked();
4607
4608 // Stream some successful content
4609 let fake_model = model.as_fake();
4610 // After the retry, there should be a new pending completion
4611 let pending = fake_model.pending_completions();
4612 assert!(
4613 !pending.is_empty(),
4614 "Should have a pending completion after retry"
4615 );
4616 fake_model.stream_completion_response(&pending[0], "Success!");
4617 fake_model.end_completion_stream(&pending[0]);
4618 cx.run_until_parked();
4619
4620 // Check that the retry completed successfully
4621 assert!(
4622 *retry_completed.lock(),
4623 "Retry should have completed successfully"
4624 );
4625
4626 // Retry message should still exist but be marked as ui_only
4627 thread.read_with(cx, |thread, _| {
4628 let retry_msg = thread
4629 .message(retry_message_id)
4630 .expect("Retry message should still exist");
4631 assert!(retry_msg.ui_only, "Retry message should be ui_only");
4632 assert_eq!(
4633 retry_msg.role,
4634 Role::System,
4635 "Retry message should have System role"
4636 );
4637 });
4638 }
4639
4640 #[gpui::test]
4641 async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4642 init_test_settings(cx);
4643
4644 let project = create_test_project(cx, json!({})).await;
4645 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4646
4647 // Create a model that fails once then succeeds
4648 struct FailOnceModel {
4649 inner: Arc<FakeLanguageModel>,
4650 failed_once: Arc<Mutex<bool>>,
4651 }
4652
4653 impl LanguageModel for FailOnceModel {
4654 fn id(&self) -> LanguageModelId {
4655 self.inner.id()
4656 }
4657
4658 fn name(&self) -> LanguageModelName {
4659 self.inner.name()
4660 }
4661
4662 fn provider_id(&self) -> LanguageModelProviderId {
4663 self.inner.provider_id()
4664 }
4665
4666 fn provider_name(&self) -> LanguageModelProviderName {
4667 self.inner.provider_name()
4668 }
4669
4670 fn supports_tools(&self) -> bool {
4671 self.inner.supports_tools()
4672 }
4673
4674 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4675 self.inner.supports_tool_choice(choice)
4676 }
4677
4678 fn supports_images(&self) -> bool {
4679 self.inner.supports_images()
4680 }
4681
4682 fn telemetry_id(&self) -> String {
4683 self.inner.telemetry_id()
4684 }
4685
4686 fn max_token_count(&self) -> u64 {
4687 self.inner.max_token_count()
4688 }
4689
4690 fn count_tokens(
4691 &self,
4692 request: LanguageModelRequest,
4693 cx: &App,
4694 ) -> BoxFuture<'static, Result<u64>> {
4695 self.inner.count_tokens(request, cx)
4696 }
4697
4698 fn stream_completion(
4699 &self,
4700 request: LanguageModelRequest,
4701 cx: &AsyncApp,
4702 ) -> BoxFuture<
4703 'static,
4704 Result<
4705 BoxStream<
4706 'static,
4707 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4708 >,
4709 LanguageModelCompletionError,
4710 >,
4711 > {
4712 if !*self.failed_once.lock() {
4713 *self.failed_once.lock() = true;
4714 let provider = self.provider_name();
4715 // Return error on first attempt
4716 let stream = futures::stream::once(async move {
4717 Err(LanguageModelCompletionError::ServerOverloaded {
4718 provider,
4719 retry_after: None,
4720 })
4721 });
4722 async move { Ok(stream.boxed()) }.boxed()
4723 } else {
4724 // Succeed on retry
4725 self.inner.stream_completion(request, cx)
4726 }
4727 }
4728 }
4729
4730 let fail_once_model = Arc::new(FailOnceModel {
4731 inner: Arc::new(FakeLanguageModel::default()),
4732 failed_once: Arc::new(Mutex::new(false)),
4733 });
4734
4735 // Insert a user message
4736 thread.update(cx, |thread, cx| {
4737 thread.insert_user_message(
4738 "Test message",
4739 ContextLoadResult::default(),
4740 None,
4741 vec![],
4742 cx,
4743 );
4744 });
4745
4746 // Start completion with fail-once model
4747 thread.update(cx, |thread, cx| {
4748 thread.send_to_model(
4749 fail_once_model.clone(),
4750 CompletionIntent::UserPrompt,
4751 None,
4752 cx,
4753 );
4754 });
4755
4756 cx.run_until_parked();
4757
4758 // Verify retry state exists after first failure
4759 thread.read_with(cx, |thread, _| {
4760 assert!(
4761 thread.retry_state.is_some(),
4762 "Should have retry state after failure"
4763 );
4764 });
4765
4766 // Wait for retry delay
4767 cx.executor().advance_clock(BASE_RETRY_DELAY);
4768 cx.run_until_parked();
4769
4770 // The retry should now use our FailOnceModel which should succeed
4771 // We need to help the FakeLanguageModel complete the stream
4772 let inner_fake = fail_once_model.inner.clone();
4773
4774 // Wait a bit for the retry to start
4775 cx.run_until_parked();
4776
4777 // Check for pending completions and complete them
4778 if let Some(pending) = inner_fake.pending_completions().first() {
4779 inner_fake.stream_completion_response(pending, "Success!");
4780 inner_fake.end_completion_stream(pending);
4781 }
4782 cx.run_until_parked();
4783
4784 thread.read_with(cx, |thread, _| {
4785 assert!(
4786 thread.retry_state.is_none(),
4787 "Retry state should be cleared after successful completion"
4788 );
4789
4790 let has_assistant_message = thread
4791 .messages
4792 .iter()
4793 .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4794 assert!(
4795 has_assistant_message,
4796 "Should have an assistant message after successful retry"
4797 );
4798 });
4799 }
4800
4801 #[gpui::test]
4802 async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4803 init_test_settings(cx);
4804
4805 let project = create_test_project(cx, json!({})).await;
4806 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4807
4808 // Create a model that returns rate limit error with retry_after
4809 struct RateLimitModel {
4810 inner: Arc<FakeLanguageModel>,
4811 }
4812
4813 impl LanguageModel for RateLimitModel {
4814 fn id(&self) -> LanguageModelId {
4815 self.inner.id()
4816 }
4817
4818 fn name(&self) -> LanguageModelName {
4819 self.inner.name()
4820 }
4821
4822 fn provider_id(&self) -> LanguageModelProviderId {
4823 self.inner.provider_id()
4824 }
4825
4826 fn provider_name(&self) -> LanguageModelProviderName {
4827 self.inner.provider_name()
4828 }
4829
4830 fn supports_tools(&self) -> bool {
4831 self.inner.supports_tools()
4832 }
4833
4834 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4835 self.inner.supports_tool_choice(choice)
4836 }
4837
4838 fn supports_images(&self) -> bool {
4839 self.inner.supports_images()
4840 }
4841
4842 fn telemetry_id(&self) -> String {
4843 self.inner.telemetry_id()
4844 }
4845
4846 fn max_token_count(&self) -> u64 {
4847 self.inner.max_token_count()
4848 }
4849
4850 fn count_tokens(
4851 &self,
4852 request: LanguageModelRequest,
4853 cx: &App,
4854 ) -> BoxFuture<'static, Result<u64>> {
4855 self.inner.count_tokens(request, cx)
4856 }
4857
4858 fn stream_completion(
4859 &self,
4860 _request: LanguageModelRequest,
4861 _cx: &AsyncApp,
4862 ) -> BoxFuture<
4863 'static,
4864 Result<
4865 BoxStream<
4866 'static,
4867 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4868 >,
4869 LanguageModelCompletionError,
4870 >,
4871 > {
4872 let provider = self.provider_name();
4873 async move {
4874 let stream = futures::stream::once(async move {
4875 Err(LanguageModelCompletionError::RateLimitExceeded {
4876 provider,
4877 retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4878 })
4879 });
4880 Ok(stream.boxed())
4881 }
4882 .boxed()
4883 }
4884
4885 fn as_fake(&self) -> &FakeLanguageModel {
4886 &self.inner
4887 }
4888 }
4889
4890 let model = Arc::new(RateLimitModel {
4891 inner: Arc::new(FakeLanguageModel::default()),
4892 });
4893
4894 // Insert a user message
4895 thread.update(cx, |thread, cx| {
4896 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4897 });
4898
4899 // Start completion
4900 thread.update(cx, |thread, cx| {
4901 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4902 });
4903
4904 cx.run_until_parked();
4905
4906 let retry_count = thread.update(cx, |thread, _| {
4907 thread
4908 .messages
4909 .iter()
4910 .filter(|m| {
4911 m.ui_only
4912 && m.segments.iter().any(|s| {
4913 if let MessageSegment::Text(text) = s {
4914 text.contains("rate limit exceeded")
4915 } else {
4916 false
4917 }
4918 })
4919 })
4920 .count()
4921 });
4922 assert_eq!(retry_count, 1, "Should have scheduled one retry");
4923
4924 thread.read_with(cx, |thread, _| {
4925 assert!(
4926 thread.retry_state.is_some(),
4927 "Rate limit errors should set retry_state"
4928 );
4929 if let Some(retry_state) = &thread.retry_state {
4930 assert_eq!(
4931 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4932 "Rate limit errors should use MAX_RETRY_ATTEMPTS"
4933 );
4934 }
4935 });
4936
4937 // Verify we have one retry message
4938 thread.read_with(cx, |thread, _| {
4939 let retry_messages = thread
4940 .messages
4941 .iter()
4942 .filter(|msg| {
4943 msg.ui_only
4944 && msg.segments.iter().any(|seg| {
4945 if let MessageSegment::Text(text) = seg {
4946 text.contains("rate limit exceeded")
4947 } else {
4948 false
4949 }
4950 })
4951 })
4952 .count();
4953 assert_eq!(
4954 retry_messages, 1,
4955 "Should have one rate limit retry message"
4956 );
4957 });
4958
4959 // Check that retry message doesn't include attempt count
4960 thread.read_with(cx, |thread, _| {
4961 let retry_message = thread
4962 .messages
4963 .iter()
4964 .find(|msg| msg.role == Role::System && msg.ui_only)
4965 .expect("Should have a retry message");
4966
4967 // Check that the message contains attempt count since we use retry_state
4968 if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
4969 assert!(
4970 text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)),
4971 "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
4972 );
4973 assert!(
4974 text.contains("Retrying"),
4975 "Rate limit retry message should contain retry text"
4976 );
4977 }
4978 });
4979 }
4980
4981 #[gpui::test]
4982 async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
4983 init_test_settings(cx);
4984
4985 let project = create_test_project(cx, json!({})).await;
4986 let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
4987
4988 // Insert a regular user message
4989 thread.update(cx, |thread, cx| {
4990 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4991 });
4992
4993 // Insert a UI-only message (like our retry notifications)
4994 thread.update(cx, |thread, cx| {
4995 let id = thread.next_message_id.post_inc();
4996 thread.messages.push(Message {
4997 id,
4998 role: Role::System,
4999 segments: vec![MessageSegment::Text(
5000 "This is a UI-only message that should not be sent to the model".to_string(),
5001 )],
5002 loaded_context: LoadedContext::default(),
5003 creases: Vec::new(),
5004 is_hidden: true,
5005 ui_only: true,
5006 });
5007 cx.emit(ThreadEvent::MessageAdded(id));
5008 });
5009
5010 // Insert another regular message
5011 thread.update(cx, |thread, cx| {
5012 thread.insert_user_message(
5013 "How are you?",
5014 ContextLoadResult::default(),
5015 None,
5016 vec![],
5017 cx,
5018 );
5019 });
5020
5021 // Generate the completion request
5022 let request = thread.update(cx, |thread, cx| {
5023 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5024 });
5025
5026 // Verify that the request only contains non-UI-only messages
5027 // Should have system prompt + 2 user messages, but not the UI-only message
5028 let user_messages: Vec<_> = request
5029 .messages
5030 .iter()
5031 .filter(|msg| msg.role == Role::User)
5032 .collect();
5033 assert_eq!(
5034 user_messages.len(),
5035 2,
5036 "Should have exactly 2 user messages"
5037 );
5038
5039 // Verify the UI-only content is not present anywhere in the request
5040 let request_text = request
5041 .messages
5042 .iter()
5043 .flat_map(|msg| &msg.content)
5044 .filter_map(|content| match content {
5045 MessageContent::Text(text) => Some(text.as_str()),
5046 _ => None,
5047 })
5048 .collect::<String>();
5049
5050 assert!(
5051 !request_text.contains("UI-only message"),
5052 "UI-only message content should not be in the request"
5053 );
5054
5055 // Verify the thread still has all 3 messages (including UI-only)
5056 thread.read_with(cx, |thread, _| {
5057 assert_eq!(
5058 thread.messages().count(),
5059 3,
5060 "Thread should have 3 messages"
5061 );
5062 assert_eq!(
5063 thread.messages().filter(|m| m.ui_only).count(),
5064 1,
5065 "Thread should have 1 UI-only message"
5066 );
5067 });
5068
5069 // Verify that UI-only messages are not serialized
5070 let serialized = thread
5071 .update(cx, |thread, cx| thread.serialize(cx))
5072 .await
5073 .unwrap();
5074 assert_eq!(
5075 serialized.messages.len(),
5076 2,
5077 "Serialized thread should only have 2 messages (no UI-only)"
5078 );
5079 }
5080
5081 #[gpui::test]
5082 async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5083 init_test_settings(cx);
5084
5085 let project = create_test_project(cx, json!({})).await;
5086 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5087
5088 // Create model that returns overloaded error
5089 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5090
5091 // Insert a user message
5092 thread.update(cx, |thread, cx| {
5093 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5094 });
5095
5096 // Start completion
5097 thread.update(cx, |thread, cx| {
5098 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5099 });
5100
5101 cx.run_until_parked();
5102
5103 // Verify retry was scheduled by checking for retry message
5104 let has_retry_message = thread.read_with(cx, |thread, _| {
5105 thread.messages.iter().any(|m| {
5106 m.ui_only
5107 && m.segments.iter().any(|s| {
5108 if let MessageSegment::Text(text) = s {
5109 text.contains("Retrying") && text.contains("seconds")
5110 } else {
5111 false
5112 }
5113 })
5114 })
5115 });
5116 assert!(has_retry_message, "Should have scheduled a retry");
5117
5118 // Cancel the completion before the retry happens
5119 thread.update(cx, |thread, cx| {
5120 thread.cancel_last_completion(None, cx);
5121 });
5122
5123 cx.run_until_parked();
5124
5125 // The retry should not have happened - no pending completions
5126 let fake_model = model.as_fake();
5127 assert_eq!(
5128 fake_model.pending_completions().len(),
5129 0,
5130 "Should have no pending completions after cancellation"
5131 );
5132
5133 // Verify the retry was cancelled by checking retry state
5134 thread.read_with(cx, |thread, _| {
5135 if let Some(retry_state) = &thread.retry_state {
5136 panic!(
5137 "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5138 retry_state.attempt, retry_state.max_attempts, retry_state.intent
5139 );
5140 }
5141 });
5142 }
5143
5144 fn test_summarize_error(
5145 model: &Arc<dyn LanguageModel>,
5146 thread: &Entity<Thread>,
5147 cx: &mut TestAppContext,
5148 ) {
5149 thread.update(cx, |thread, cx| {
5150 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5151 thread.send_to_model(
5152 model.clone(),
5153 CompletionIntent::ThreadSummarization,
5154 None,
5155 cx,
5156 );
5157 });
5158
5159 let fake_model = model.as_fake();
5160 simulate_successful_response(&fake_model, cx);
5161
5162 thread.read_with(cx, |thread, _| {
5163 assert!(matches!(thread.summary(), ThreadSummary::Generating));
5164 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5165 });
5166
5167 // Simulate summary request ending
5168 cx.run_until_parked();
5169 fake_model.end_last_completion_stream();
5170 cx.run_until_parked();
5171
5172 // State is set to Error and default message
5173 thread.read_with(cx, |thread, _| {
5174 assert!(matches!(thread.summary(), ThreadSummary::Error));
5175 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5176 });
5177 }
5178
    /// Completes the model's most recent pending request by streaming a canned
    /// "Assistant response" and closing the stream, parking the executor
    /// before and after so all scheduled work settles.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
5185
    /// Registers the global state every test in this module depends on:
    /// settings store, language/project/agent/workspace/theme settings, prompt
    /// and thread stores, the tool registry, and the assistant tools.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
            assistant_tool::init(cx);

            // The assistant tools need an HTTP client; use a fake one that
            // always answers 200 and never touches the network.
            let http_client = Arc::new(http_client::HttpClientWithUrl::new(
                http_client::FakeHttpClient::with_200_response(),
                "http://localhost".to_string(),
                None,
            ));
            assistant_tools::init(http_client, cx);
        });
    }
5209
5210 // Helper to create a test project with test files
5211 async fn create_test_project(
5212 cx: &mut TestAppContext,
5213 files: serde_json::Value,
5214 ) -> Entity<Project> {
5215 let fs = FakeFs::new(cx.executor());
5216 fs.insert_tree(path!("/test"), files).await;
5217 Project::test(fs, [path!("/test").as_ref()], cx).await
5218 }
5219
5220 async fn setup_test_environment(
5221 cx: &mut TestAppContext,
5222 project: Entity<Project>,
5223 ) -> (
5224 Entity<Workspace>,
5225 Entity<ThreadStore>,
5226 Entity<Thread>,
5227 Entity<ContextStore>,
5228 Arc<dyn LanguageModel>,
5229 ) {
5230 let (workspace, cx) =
5231 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5232
5233 let thread_store = cx
5234 .update(|_, cx| {
5235 ThreadStore::load(
5236 project.clone(),
5237 cx.new(|_| ToolWorkingSet::default()),
5238 None,
5239 Arc::new(PromptBuilder::new(None).unwrap()),
5240 cx,
5241 )
5242 })
5243 .await
5244 .unwrap();
5245
5246 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5247 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5248
5249 let provider = Arc::new(FakeLanguageModelProvider);
5250 let model = provider.test_model();
5251 let model: Arc<dyn LanguageModel> = Arc::new(model);
5252
5253 cx.update(|_, cx| {
5254 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5255 registry.set_default_model(
5256 Some(ConfiguredModel {
5257 provider: provider.clone(),
5258 model: model.clone(),
5259 }),
5260 cx,
5261 );
5262 registry.set_thread_summary_model(
5263 Some(ConfiguredModel {
5264 provider,
5265 model: model.clone(),
5266 }),
5267 cx,
5268 );
5269 })
5270 });
5271
5272 (workspace, thread_store, thread, context_store, model)
5273 }
5274
5275 async fn add_file_to_context(
5276 project: &Entity<Project>,
5277 context_store: &Entity<ContextStore>,
5278 path: &str,
5279 cx: &mut TestAppContext,
5280 ) -> Result<Entity<language::Buffer>> {
5281 let buffer_path = project
5282 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5283 .unwrap();
5284
5285 let buffer = project
5286 .update(cx, |project, cx| {
5287 project.open_buffer(buffer_path.clone(), cx)
5288 })
5289 .await
5290 .unwrap();
5291
5292 context_store.update(cx, |context_store, cx| {
5293 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5294 });
5295
5296 Ok(buffer)
5297 }
5298}