1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9 tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
10};
11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
12use anyhow::{Result, anyhow};
13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
14use chrono::{DateTime, Utc};
15use client::{ModelRequestUsage, RequestUsage};
16use collections::HashMap;
17use feature_flags::{self, FeatureFlagAppExt};
18use futures::{FutureExt, StreamExt as _, future::Shared};
19use git::repository::DiffType;
20use gpui::{
21 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
22 WeakEntity, Window,
23};
24use http_client::StatusCode;
25use language_model::{
26 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
27 LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
28 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
29 LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
30 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
31 TokenUsage,
32};
33use postage::stream::Stream as _;
34use project::{
35 Project,
36 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
37};
38use prompt_store::{ModelContext, PromptBuilder};
39use proto::Plan;
40use schemars::JsonSchema;
41use serde::{Deserialize, Serialize};
42use settings::Settings;
43use std::{
44 io::Write,
45 ops::Range,
46 sync::Arc,
47 time::{Duration, Instant},
48};
49use thiserror::Error;
50use util::{ResultExt as _, debug_panic, post_inc};
51use uuid::Uuid;
52use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
53
54const MAX_RETRY_ATTEMPTS: u8 = 3;
55const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
56
57#[derive(Debug, Clone)]
58enum RetryStrategy {
59 ExponentialBackoff {
60 initial_delay: Duration,
61 max_attempts: u8,
62 },
63 Fixed {
64 delay: Duration,
65 max_attempts: u8,
66 },
67}
68
69#[derive(
70 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
71)]
72pub struct ThreadId(Arc<str>);
73
74impl ThreadId {
75 pub fn new() -> Self {
76 Self(Uuid::new_v4().to_string().into())
77 }
78}
79
80impl std::fmt::Display for ThreadId {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 write!(f, "{}", self.0)
83 }
84}
85
86impl From<&str> for ThreadId {
87 fn from(value: &str) -> Self {
88 Self(value.into())
89 }
90}
91
92/// The ID of the user prompt that initiated a request.
93///
94/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
95#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
96pub struct PromptId(Arc<str>);
97
98impl PromptId {
99 pub fn new() -> Self {
100 Self(Uuid::new_v4().to_string().into())
101 }
102}
103
104impl std::fmt::Display for PromptId {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 write!(f, "{}", self.0)
107 }
108}
109
110#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
111pub struct MessageId(pub(crate) usize);
112
113impl MessageId {
114 fn post_inc(&mut self) -> Self {
115 Self(post_inc(&mut self.0))
116 }
117
118 pub fn as_usize(&self) -> usize {
119 self.0
120 }
121}
122
123/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
124#[derive(Clone, Debug)]
125pub struct MessageCrease {
126 pub range: Range<usize>,
127 pub icon_path: SharedString,
128 pub label: SharedString,
129 /// None for a deserialized message, Some otherwise.
130 pub context: Option<AgentContextHandle>,
131}
132
133/// A message in a [`Thread`].
134#[derive(Debug, Clone)]
135pub struct Message {
136 pub id: MessageId,
137 pub role: Role,
138 pub segments: Vec<MessageSegment>,
139 pub loaded_context: LoadedContext,
140 pub creases: Vec<MessageCrease>,
141 pub is_hidden: bool,
142 pub ui_only: bool,
143}
144
145impl Message {
146 /// Returns whether the message contains any meaningful text that should be displayed
147 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
148 pub fn should_display_content(&self) -> bool {
149 self.segments.iter().all(|segment| segment.should_display())
150 }
151
152 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
153 if let Some(MessageSegment::Thinking {
154 text: segment,
155 signature: current_signature,
156 }) = self.segments.last_mut()
157 {
158 if let Some(signature) = signature {
159 *current_signature = Some(signature);
160 }
161 segment.push_str(text);
162 } else {
163 self.segments.push(MessageSegment::Thinking {
164 text: text.to_string(),
165 signature,
166 });
167 }
168 }
169
170 pub fn push_redacted_thinking(&mut self, data: String) {
171 self.segments.push(MessageSegment::RedactedThinking(data));
172 }
173
174 pub fn push_text(&mut self, text: &str) {
175 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
176 segment.push_str(text);
177 } else {
178 self.segments.push(MessageSegment::Text(text.to_string()));
179 }
180 }
181
182 pub fn to_string(&self) -> String {
183 let mut result = String::new();
184
185 if !self.loaded_context.text.is_empty() {
186 result.push_str(&self.loaded_context.text);
187 }
188
189 for segment in &self.segments {
190 match segment {
191 MessageSegment::Text(text) => result.push_str(text),
192 MessageSegment::Thinking { text, .. } => {
193 result.push_str("<think>\n");
194 result.push_str(text);
195 result.push_str("\n</think>");
196 }
197 MessageSegment::RedactedThinking(_) => {}
198 }
199 }
200
201 result
202 }
203}
204
205#[derive(Debug, Clone, PartialEq, Eq)]
206pub enum MessageSegment {
207 Text(String),
208 Thinking {
209 text: String,
210 signature: Option<String>,
211 },
212 RedactedThinking(String),
213}
214
215impl MessageSegment {
216 pub fn should_display(&self) -> bool {
217 match self {
218 Self::Text(text) => text.is_empty(),
219 Self::Thinking { text, .. } => text.is_empty(),
220 Self::RedactedThinking(_) => false,
221 }
222 }
223
224 pub fn text(&self) -> Option<&str> {
225 match self {
226 MessageSegment::Text(text) => Some(text),
227 _ => None,
228 }
229 }
230}
231
232#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
233pub struct ProjectSnapshot {
234 pub worktree_snapshots: Vec<WorktreeSnapshot>,
235 pub unsaved_buffer_paths: Vec<String>,
236 pub timestamp: DateTime<Utc>,
237}
238
239#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
240pub struct WorktreeSnapshot {
241 pub worktree_path: String,
242 pub git_state: Option<GitState>,
243}
244
245#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
246pub struct GitState {
247 pub remote_url: Option<String>,
248 pub head_sha: Option<String>,
249 pub current_branch: Option<String>,
250 pub diff: Option<String>,
251}
252
253#[derive(Clone, Debug)]
254pub struct ThreadCheckpoint {
255 message_id: MessageId,
256 git_checkpoint: GitStoreCheckpoint,
257}
258
259#[derive(Copy, Clone, Debug, PartialEq, Eq)]
260pub enum ThreadFeedback {
261 Positive,
262 Negative,
263}
264
265pub enum LastRestoreCheckpoint {
266 Pending {
267 message_id: MessageId,
268 },
269 Error {
270 message_id: MessageId,
271 error: String,
272 },
273}
274
275impl LastRestoreCheckpoint {
276 pub fn message_id(&self) -> MessageId {
277 match self {
278 LastRestoreCheckpoint::Pending { message_id } => *message_id,
279 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
280 }
281 }
282}
283
284#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
285pub enum DetailedSummaryState {
286 #[default]
287 NotGenerated,
288 Generating {
289 message_id: MessageId,
290 },
291 Generated {
292 text: SharedString,
293 message_id: MessageId,
294 },
295}
296
297impl DetailedSummaryState {
298 fn text(&self) -> Option<SharedString> {
299 if let Self::Generated { text, .. } = self {
300 Some(text.clone())
301 } else {
302 None
303 }
304 }
305}
306
307#[derive(Default, Debug)]
308pub struct TotalTokenUsage {
309 pub total: u64,
310 pub max: u64,
311}
312
313impl TotalTokenUsage {
314 pub fn ratio(&self) -> TokenUsageRatio {
315 #[cfg(debug_assertions)]
316 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
317 .unwrap_or("0.8".to_string())
318 .parse()
319 .unwrap();
320 #[cfg(not(debug_assertions))]
321 let warning_threshold: f32 = 0.8;
322
323 // When the maximum is unknown because there is no selected model,
324 // avoid showing the token limit warning.
325 if self.max == 0 {
326 TokenUsageRatio::Normal
327 } else if self.total >= self.max {
328 TokenUsageRatio::Exceeded
329 } else if self.total as f32 / self.max as f32 >= warning_threshold {
330 TokenUsageRatio::Warning
331 } else {
332 TokenUsageRatio::Normal
333 }
334 }
335
336 pub fn add(&self, tokens: u64) -> TotalTokenUsage {
337 TotalTokenUsage {
338 total: self.total + tokens,
339 max: self.max,
340 }
341 }
342}
343
344#[derive(Debug, Default, PartialEq, Eq)]
345pub enum TokenUsageRatio {
346 #[default]
347 Normal,
348 Warning,
349 Exceeded,
350}
351
352#[derive(Debug, Clone, Copy)]
353pub enum QueueState {
354 Sending,
355 Queued { position: usize },
356 Started,
357}
358
359/// A thread of conversation with the LLM.
360pub struct Thread {
361 id: ThreadId,
362 updated_at: DateTime<Utc>,
363 summary: ThreadSummary,
364 pending_summary: Task<Option<()>>,
365 detailed_summary_task: Task<Option<()>>,
366 detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
367 detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
368 completion_mode: agent_settings::CompletionMode,
369 messages: Vec<Message>,
370 next_message_id: MessageId,
371 last_prompt_id: PromptId,
372 project_context: SharedProjectContext,
373 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
374 completion_count: usize,
375 pending_completions: Vec<PendingCompletion>,
376 project: Entity<Project>,
377 prompt_builder: Arc<PromptBuilder>,
378 tools: Entity<ToolWorkingSet>,
379 tool_use: ToolUseState,
380 action_log: Entity<ActionLog>,
381 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
382 pending_checkpoint: Option<ThreadCheckpoint>,
383 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
384 request_token_usage: Vec<TokenUsage>,
385 cumulative_token_usage: TokenUsage,
386 exceeded_window_error: Option<ExceededWindowError>,
387 tool_use_limit_reached: bool,
388 feedback: Option<ThreadFeedback>,
389 retry_state: Option<RetryState>,
390 message_feedback: HashMap<MessageId, ThreadFeedback>,
391 last_auto_capture_at: Option<Instant>,
392 last_received_chunk_at: Option<Instant>,
393 request_callback: Option<
394 Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
395 >,
396 remaining_turns: u32,
397 configured_model: Option<ConfiguredModel>,
398 profile: AgentProfile,
399}
400
401#[derive(Clone, Debug)]
402struct RetryState {
403 attempt: u8,
404 max_attempts: u8,
405 intent: CompletionIntent,
406}
407
408#[derive(Clone, Debug, PartialEq, Eq)]
409pub enum ThreadSummary {
410 Pending,
411 Generating,
412 Ready(SharedString),
413 Error,
414}
415
416impl ThreadSummary {
417 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
418
419 pub fn or_default(&self) -> SharedString {
420 self.unwrap_or(Self::DEFAULT)
421 }
422
423 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
424 self.ready().unwrap_or_else(|| message.into())
425 }
426
427 pub fn ready(&self) -> Option<SharedString> {
428 match self {
429 ThreadSummary::Ready(summary) => Some(summary.clone()),
430 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
431 }
432 }
433}
434
435#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
436pub struct ExceededWindowError {
437 /// Model used when last message exceeded context window
438 model_id: LanguageModelId,
439 /// Token count including last message
440 token_count: u64,
441}
442
443impl Thread {
444 pub fn new(
445 project: Entity<Project>,
446 tools: Entity<ToolWorkingSet>,
447 prompt_builder: Arc<PromptBuilder>,
448 system_prompt: SharedProjectContext,
449 cx: &mut Context<Self>,
450 ) -> Self {
451 let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
452 let configured_model = LanguageModelRegistry::read_global(cx).default_model();
453 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
454
455 Self {
456 id: ThreadId::new(),
457 updated_at: Utc::now(),
458 summary: ThreadSummary::Pending,
459 pending_summary: Task::ready(None),
460 detailed_summary_task: Task::ready(None),
461 detailed_summary_tx,
462 detailed_summary_rx,
463 completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
464 messages: Vec::new(),
465 next_message_id: MessageId(0),
466 last_prompt_id: PromptId::new(),
467 project_context: system_prompt,
468 checkpoints_by_message: HashMap::default(),
469 completion_count: 0,
470 pending_completions: Vec::new(),
471 project: project.clone(),
472 prompt_builder,
473 tools: tools.clone(),
474 last_restore_checkpoint: None,
475 pending_checkpoint: None,
476 tool_use: ToolUseState::new(tools.clone()),
477 action_log: cx.new(|_| ActionLog::new(project.clone())),
478 initial_project_snapshot: {
479 let project_snapshot = Self::project_snapshot(project, cx);
480 cx.foreground_executor()
481 .spawn(async move { Some(project_snapshot.await) })
482 .shared()
483 },
484 request_token_usage: Vec::new(),
485 cumulative_token_usage: TokenUsage::default(),
486 exceeded_window_error: None,
487 tool_use_limit_reached: false,
488 feedback: None,
489 retry_state: None,
490 message_feedback: HashMap::default(),
491 last_auto_capture_at: None,
492 last_received_chunk_at: None,
493 request_callback: None,
494 remaining_turns: u32::MAX,
495 configured_model,
496 profile: AgentProfile::new(profile_id, tools),
497 }
498 }
499
500 pub fn deserialize(
501 id: ThreadId,
502 serialized: SerializedThread,
503 project: Entity<Project>,
504 tools: Entity<ToolWorkingSet>,
505 prompt_builder: Arc<PromptBuilder>,
506 project_context: SharedProjectContext,
507 window: Option<&mut Window>, // None in headless mode
508 cx: &mut Context<Self>,
509 ) -> Self {
510 let next_message_id = MessageId(
511 serialized
512 .messages
513 .last()
514 .map(|message| message.id.0 + 1)
515 .unwrap_or(0),
516 );
517 let tool_use = ToolUseState::from_serialized_messages(
518 tools.clone(),
519 &serialized.messages,
520 project.clone(),
521 window,
522 cx,
523 );
524 let (detailed_summary_tx, detailed_summary_rx) =
525 postage::watch::channel_with(serialized.detailed_summary_state);
526
527 let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
528 serialized
529 .model
530 .and_then(|model| {
531 let model = SelectedModel {
532 provider: model.provider.clone().into(),
533 model: model.model.clone().into(),
534 };
535 registry.select_model(&model, cx)
536 })
537 .or_else(|| registry.default_model())
538 });
539
540 let completion_mode = serialized
541 .completion_mode
542 .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
543 let profile_id = serialized
544 .profile
545 .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
546
547 Self {
548 id,
549 updated_at: serialized.updated_at,
550 summary: ThreadSummary::Ready(serialized.summary),
551 pending_summary: Task::ready(None),
552 detailed_summary_task: Task::ready(None),
553 detailed_summary_tx,
554 detailed_summary_rx,
555 completion_mode,
556 retry_state: None,
557 messages: serialized
558 .messages
559 .into_iter()
560 .map(|message| Message {
561 id: message.id,
562 role: message.role,
563 segments: message
564 .segments
565 .into_iter()
566 .map(|segment| match segment {
567 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
568 SerializedMessageSegment::Thinking { text, signature } => {
569 MessageSegment::Thinking { text, signature }
570 }
571 SerializedMessageSegment::RedactedThinking { data } => {
572 MessageSegment::RedactedThinking(data)
573 }
574 })
575 .collect(),
576 loaded_context: LoadedContext {
577 contexts: Vec::new(),
578 text: message.context,
579 images: Vec::new(),
580 },
581 creases: message
582 .creases
583 .into_iter()
584 .map(|crease| MessageCrease {
585 range: crease.start..crease.end,
586 icon_path: crease.icon_path,
587 label: crease.label,
588 context: None,
589 })
590 .collect(),
591 is_hidden: message.is_hidden,
592 ui_only: false, // UI-only messages are not persisted
593 })
594 .collect(),
595 next_message_id,
596 last_prompt_id: PromptId::new(),
597 project_context,
598 checkpoints_by_message: HashMap::default(),
599 completion_count: 0,
600 pending_completions: Vec::new(),
601 last_restore_checkpoint: None,
602 pending_checkpoint: None,
603 project: project.clone(),
604 prompt_builder,
605 tools: tools.clone(),
606 tool_use,
607 action_log: cx.new(|_| ActionLog::new(project)),
608 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
609 request_token_usage: serialized.request_token_usage,
610 cumulative_token_usage: serialized.cumulative_token_usage,
611 exceeded_window_error: None,
612 tool_use_limit_reached: serialized.tool_use_limit_reached,
613 feedback: None,
614 message_feedback: HashMap::default(),
615 last_auto_capture_at: None,
616 last_received_chunk_at: None,
617 request_callback: None,
618 remaining_turns: u32::MAX,
619 configured_model,
620 profile: AgentProfile::new(profile_id, tools),
621 }
622 }
623
624 pub fn set_request_callback(
625 &mut self,
626 callback: impl 'static
627 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
628 ) {
629 self.request_callback = Some(Box::new(callback));
630 }
631
632 pub fn id(&self) -> &ThreadId {
633 &self.id
634 }
635
636 pub fn profile(&self) -> &AgentProfile {
637 &self.profile
638 }
639
640 pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
641 if &id != self.profile.id() {
642 self.profile = AgentProfile::new(id, self.tools.clone());
643 cx.emit(ThreadEvent::ProfileChanged);
644 }
645 }
646
647 pub fn is_empty(&self) -> bool {
648 self.messages.is_empty()
649 }
650
651 pub fn updated_at(&self) -> DateTime<Utc> {
652 self.updated_at
653 }
654
655 pub fn touch_updated_at(&mut self) {
656 self.updated_at = Utc::now();
657 }
658
659 pub fn advance_prompt_id(&mut self) {
660 self.last_prompt_id = PromptId::new();
661 }
662
663 pub fn project_context(&self) -> SharedProjectContext {
664 self.project_context.clone()
665 }
666
667 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
668 if self.configured_model.is_none() {
669 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
670 }
671 self.configured_model.clone()
672 }
673
674 pub fn configured_model(&self) -> Option<ConfiguredModel> {
675 self.configured_model.clone()
676 }
677
678 pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
679 self.configured_model = model;
680 cx.notify();
681 }
682
683 pub fn summary(&self) -> &ThreadSummary {
684 &self.summary
685 }
686
687 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
688 let current_summary = match &self.summary {
689 ThreadSummary::Pending | ThreadSummary::Generating => return,
690 ThreadSummary::Ready(summary) => summary,
691 ThreadSummary::Error => &ThreadSummary::DEFAULT,
692 };
693
694 let mut new_summary = new_summary.into();
695
696 if new_summary.is_empty() {
697 new_summary = ThreadSummary::DEFAULT;
698 }
699
700 if current_summary != &new_summary {
701 self.summary = ThreadSummary::Ready(new_summary);
702 cx.emit(ThreadEvent::SummaryChanged);
703 }
704 }
705
706 pub fn completion_mode(&self) -> CompletionMode {
707 self.completion_mode
708 }
709
710 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
711 self.completion_mode = mode;
712 }
713
714 pub fn message(&self, id: MessageId) -> Option<&Message> {
715 let index = self
716 .messages
717 .binary_search_by(|message| message.id.cmp(&id))
718 .ok()?;
719
720 self.messages.get(index)
721 }
722
723 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
724 self.messages.iter()
725 }
726
727 pub fn is_generating(&self) -> bool {
728 !self.pending_completions.is_empty() || !self.all_tools_finished()
729 }
730
731 /// Indicates whether streaming of language model events is stale.
732 /// When `is_generating()` is false, this method returns `None`.
733 pub fn is_generation_stale(&self) -> Option<bool> {
734 const STALE_THRESHOLD: u128 = 250;
735
736 self.last_received_chunk_at
737 .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
738 }
739
740 fn received_chunk(&mut self) {
741 self.last_received_chunk_at = Some(Instant::now());
742 }
743
744 pub fn queue_state(&self) -> Option<QueueState> {
745 self.pending_completions
746 .first()
747 .map(|pending_completion| pending_completion.queue_state)
748 }
749
750 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
751 &self.tools
752 }
753
754 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
755 self.tool_use
756 .pending_tool_uses()
757 .into_iter()
758 .find(|tool_use| &tool_use.id == id)
759 }
760
761 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
762 self.tool_use
763 .pending_tool_uses()
764 .into_iter()
765 .filter(|tool_use| tool_use.status.needs_confirmation())
766 }
767
768 pub fn has_pending_tool_uses(&self) -> bool {
769 !self.tool_use.pending_tool_uses().is_empty()
770 }
771
772 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
773 self.checkpoints_by_message.get(&id).cloned()
774 }
775
776 pub fn restore_checkpoint(
777 &mut self,
778 checkpoint: ThreadCheckpoint,
779 cx: &mut Context<Self>,
780 ) -> Task<Result<()>> {
781 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
782 message_id: checkpoint.message_id,
783 });
784 cx.emit(ThreadEvent::CheckpointChanged);
785 cx.notify();
786
787 let git_store = self.project().read(cx).git_store().clone();
788 let restore = git_store.update(cx, |git_store, cx| {
789 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
790 });
791
792 cx.spawn(async move |this, cx| {
793 let result = restore.await;
794 this.update(cx, |this, cx| {
795 if let Err(err) = result.as_ref() {
796 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
797 message_id: checkpoint.message_id,
798 error: err.to_string(),
799 });
800 } else {
801 this.truncate(checkpoint.message_id, cx);
802 this.last_restore_checkpoint = None;
803 }
804 this.pending_checkpoint = None;
805 cx.emit(ThreadEvent::CheckpointChanged);
806 cx.notify();
807 })?;
808 result
809 })
810 }
811
812 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
813 let pending_checkpoint = if self.is_generating() {
814 return;
815 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
816 checkpoint
817 } else {
818 return;
819 };
820
821 self.finalize_checkpoint(pending_checkpoint, cx);
822 }
823
824 fn finalize_checkpoint(
825 &mut self,
826 pending_checkpoint: ThreadCheckpoint,
827 cx: &mut Context<Self>,
828 ) {
829 let git_store = self.project.read(cx).git_store().clone();
830 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
831 cx.spawn(async move |this, cx| match final_checkpoint.await {
832 Ok(final_checkpoint) => {
833 let equal = git_store
834 .update(cx, |store, cx| {
835 store.compare_checkpoints(
836 pending_checkpoint.git_checkpoint.clone(),
837 final_checkpoint.clone(),
838 cx,
839 )
840 })?
841 .await
842 .unwrap_or(false);
843
844 if !equal {
845 this.update(cx, |this, cx| {
846 this.insert_checkpoint(pending_checkpoint, cx)
847 })?;
848 }
849
850 Ok(())
851 }
852 Err(_) => this.update(cx, |this, cx| {
853 this.insert_checkpoint(pending_checkpoint, cx)
854 }),
855 })
856 .detach();
857 }
858
859 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
860 self.checkpoints_by_message
861 .insert(checkpoint.message_id, checkpoint);
862 cx.emit(ThreadEvent::CheckpointChanged);
863 cx.notify();
864 }
865
866 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
867 self.last_restore_checkpoint.as_ref()
868 }
869
870 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
871 let Some(message_ix) = self
872 .messages
873 .iter()
874 .rposition(|message| message.id == message_id)
875 else {
876 return;
877 };
878 for deleted_message in self.messages.drain(message_ix..) {
879 self.checkpoints_by_message.remove(&deleted_message.id);
880 }
881 cx.notify();
882 }
883
884 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
885 self.messages
886 .iter()
887 .find(|message| message.id == id)
888 .into_iter()
889 .flat_map(|message| message.loaded_context.contexts.iter())
890 }
891
892 pub fn is_turn_end(&self, ix: usize) -> bool {
893 if self.messages.is_empty() {
894 return false;
895 }
896
897 if !self.is_generating() && ix == self.messages.len() - 1 {
898 return true;
899 }
900
901 let Some(message) = self.messages.get(ix) else {
902 return false;
903 };
904
905 if message.role != Role::Assistant {
906 return false;
907 }
908
909 self.messages
910 .get(ix + 1)
911 .and_then(|message| {
912 self.message(message.id)
913 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
914 })
915 .unwrap_or(false)
916 }
917
918 pub fn tool_use_limit_reached(&self) -> bool {
919 self.tool_use_limit_reached
920 }
921
922 /// Returns whether all of the tool uses have finished running.
923 pub fn all_tools_finished(&self) -> bool {
924 // If the only pending tool uses left are the ones with errors, then
925 // that means that we've finished running all of the pending tools.
926 self.tool_use
927 .pending_tool_uses()
928 .iter()
929 .all(|pending_tool_use| pending_tool_use.status.is_error())
930 }
931
932 /// Returns whether any pending tool uses may perform edits
933 pub fn has_pending_edit_tool_uses(&self) -> bool {
934 self.tool_use
935 .pending_tool_uses()
936 .iter()
937 .filter(|pending_tool_use| !pending_tool_use.status.is_error())
938 .any(|pending_tool_use| pending_tool_use.may_perform_edits)
939 }
940
941 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
942 self.tool_use.tool_uses_for_message(id, cx)
943 }
944
945 pub fn tool_results_for_message(
946 &self,
947 assistant_message_id: MessageId,
948 ) -> Vec<&LanguageModelToolResult> {
949 self.tool_use.tool_results_for_message(assistant_message_id)
950 }
951
952 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
953 self.tool_use.tool_result(id)
954 }
955
956 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
957 match &self.tool_use.tool_result(id)?.content {
958 LanguageModelToolResultContent::Text(text) => Some(text),
959 LanguageModelToolResultContent::Image(_) => {
960 // TODO: We should display image
961 None
962 }
963 }
964 }
965
966 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
967 self.tool_use.tool_result_card(id).cloned()
968 }
969
970 /// Return tools that are both enabled and supported by the model
971 pub fn available_tools(
972 &self,
973 cx: &App,
974 model: Arc<dyn LanguageModel>,
975 ) -> Vec<LanguageModelRequestTool> {
976 if model.supports_tools() {
977 self.profile
978 .enabled_tools(cx)
979 .into_iter()
980 .filter_map(|(name, tool)| {
981 // Skip tools that cannot be supported
982 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
983 Some(LanguageModelRequestTool {
984 name: name.into(),
985 description: tool.description(),
986 input_schema,
987 })
988 })
989 .collect()
990 } else {
991 Vec::default()
992 }
993 }
994
995 pub fn insert_user_message(
996 &mut self,
997 text: impl Into<String>,
998 loaded_context: ContextLoadResult,
999 git_checkpoint: Option<GitStoreCheckpoint>,
1000 creases: Vec<MessageCrease>,
1001 cx: &mut Context<Self>,
1002 ) -> MessageId {
1003 if !loaded_context.referenced_buffers.is_empty() {
1004 self.action_log.update(cx, |log, cx| {
1005 for buffer in loaded_context.referenced_buffers {
1006 log.buffer_read(buffer, cx);
1007 }
1008 });
1009 }
1010
1011 let message_id = self.insert_message(
1012 Role::User,
1013 vec![MessageSegment::Text(text.into())],
1014 loaded_context.loaded_context,
1015 creases,
1016 false,
1017 cx,
1018 );
1019
1020 if let Some(git_checkpoint) = git_checkpoint {
1021 self.pending_checkpoint = Some(ThreadCheckpoint {
1022 message_id,
1023 git_checkpoint,
1024 });
1025 }
1026
1027 self.auto_capture_telemetry(cx);
1028
1029 message_id
1030 }
1031
1032 pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
1033 let id = self.insert_message(
1034 Role::User,
1035 vec![MessageSegment::Text("Continue where you left off".into())],
1036 LoadedContext::default(),
1037 vec![],
1038 true,
1039 cx,
1040 );
1041 self.pending_checkpoint = None;
1042
1043 id
1044 }
1045
1046 pub fn insert_assistant_message(
1047 &mut self,
1048 segments: Vec<MessageSegment>,
1049 cx: &mut Context<Self>,
1050 ) -> MessageId {
1051 self.insert_message(
1052 Role::Assistant,
1053 segments,
1054 LoadedContext::default(),
1055 Vec::new(),
1056 false,
1057 cx,
1058 )
1059 }
1060
1061 pub fn insert_message(
1062 &mut self,
1063 role: Role,
1064 segments: Vec<MessageSegment>,
1065 loaded_context: LoadedContext,
1066 creases: Vec<MessageCrease>,
1067 is_hidden: bool,
1068 cx: &mut Context<Self>,
1069 ) -> MessageId {
1070 let id = self.next_message_id.post_inc();
1071 self.messages.push(Message {
1072 id,
1073 role,
1074 segments,
1075 loaded_context,
1076 creases,
1077 is_hidden,
1078 ui_only: false,
1079 });
1080 self.touch_updated_at();
1081 cx.emit(ThreadEvent::MessageAdded(id));
1082 id
1083 }
1084
1085 pub fn edit_message(
1086 &mut self,
1087 id: MessageId,
1088 new_role: Role,
1089 new_segments: Vec<MessageSegment>,
1090 creases: Vec<MessageCrease>,
1091 loaded_context: Option<LoadedContext>,
1092 checkpoint: Option<GitStoreCheckpoint>,
1093 cx: &mut Context<Self>,
1094 ) -> bool {
1095 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1096 return false;
1097 };
1098 message.role = new_role;
1099 message.segments = new_segments;
1100 message.creases = creases;
1101 if let Some(context) = loaded_context {
1102 message.loaded_context = context;
1103 }
1104 if let Some(git_checkpoint) = checkpoint {
1105 self.checkpoints_by_message.insert(
1106 id,
1107 ThreadCheckpoint {
1108 message_id: id,
1109 git_checkpoint,
1110 },
1111 );
1112 }
1113 self.touch_updated_at();
1114 cx.emit(ThreadEvent::MessageEdited(id));
1115 true
1116 }
1117
1118 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1119 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1120 return false;
1121 };
1122 self.messages.remove(index);
1123 self.touch_updated_at();
1124 cx.emit(ThreadEvent::MessageDeleted(id));
1125 true
1126 }
1127
1128 /// Returns the representation of this [`Thread`] in a textual form.
1129 ///
1130 /// This is the representation we use when attaching a thread as context to another thread.
1131 pub fn text(&self) -> String {
1132 let mut text = String::new();
1133
1134 for message in &self.messages {
1135 text.push_str(match message.role {
1136 language_model::Role::User => "User:",
1137 language_model::Role::Assistant => "Agent:",
1138 language_model::Role::System => "System:",
1139 });
1140 text.push('\n');
1141
1142 for segment in &message.segments {
1143 match segment {
1144 MessageSegment::Text(content) => text.push_str(content),
1145 MessageSegment::Thinking { text: content, .. } => {
1146 text.push_str(&format!("<think>{}</think>", content))
1147 }
1148 MessageSegment::RedactedThinking(_) => {}
1149 }
1150 }
1151 text.push('\n');
1152 }
1153
1154 text
1155 }
1156
1157 /// Serializes this thread into a format for storage or telemetry.
1158 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
1159 let initial_project_snapshot = self.initial_project_snapshot.clone();
1160 cx.spawn(async move |this, cx| {
1161 let initial_project_snapshot = initial_project_snapshot.await;
1162 this.read_with(cx, |this, cx| SerializedThread {
1163 version: SerializedThread::VERSION.to_string(),
1164 summary: this.summary().or_default(),
1165 updated_at: this.updated_at(),
1166 messages: this
1167 .messages()
1168 .filter(|message| !message.ui_only)
1169 .map(|message| SerializedMessage {
1170 id: message.id,
1171 role: message.role,
1172 segments: message
1173 .segments
1174 .iter()
1175 .map(|segment| match segment {
1176 MessageSegment::Text(text) => {
1177 SerializedMessageSegment::Text { text: text.clone() }
1178 }
1179 MessageSegment::Thinking { text, signature } => {
1180 SerializedMessageSegment::Thinking {
1181 text: text.clone(),
1182 signature: signature.clone(),
1183 }
1184 }
1185 MessageSegment::RedactedThinking(data) => {
1186 SerializedMessageSegment::RedactedThinking {
1187 data: data.clone(),
1188 }
1189 }
1190 })
1191 .collect(),
1192 tool_uses: this
1193 .tool_uses_for_message(message.id, cx)
1194 .into_iter()
1195 .map(|tool_use| SerializedToolUse {
1196 id: tool_use.id,
1197 name: tool_use.name,
1198 input: tool_use.input,
1199 })
1200 .collect(),
1201 tool_results: this
1202 .tool_results_for_message(message.id)
1203 .into_iter()
1204 .map(|tool_result| SerializedToolResult {
1205 tool_use_id: tool_result.tool_use_id.clone(),
1206 is_error: tool_result.is_error,
1207 content: tool_result.content.clone(),
1208 output: tool_result.output.clone(),
1209 })
1210 .collect(),
1211 context: message.loaded_context.text.clone(),
1212 creases: message
1213 .creases
1214 .iter()
1215 .map(|crease| SerializedCrease {
1216 start: crease.range.start,
1217 end: crease.range.end,
1218 icon_path: crease.icon_path.clone(),
1219 label: crease.label.clone(),
1220 })
1221 .collect(),
1222 is_hidden: message.is_hidden,
1223 })
1224 .collect(),
1225 initial_project_snapshot,
1226 cumulative_token_usage: this.cumulative_token_usage,
1227 request_token_usage: this.request_token_usage.clone(),
1228 detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
1229 exceeded_window_error: this.exceeded_window_error.clone(),
1230 model: this
1231 .configured_model
1232 .as_ref()
1233 .map(|model| SerializedLanguageModel {
1234 provider: model.provider.id().0.to_string(),
1235 model: model.model.id().0.to_string(),
1236 }),
1237 completion_mode: Some(this.completion_mode),
1238 tool_use_limit_reached: this.tool_use_limit_reached,
1239 profile: Some(this.profile.id().clone()),
1240 })
1241 })
1242 }
1243
1244 pub fn remaining_turns(&self) -> u32 {
1245 self.remaining_turns
1246 }
1247
1248 pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
1249 self.remaining_turns = remaining_turns;
1250 }
1251
1252 pub fn send_to_model(
1253 &mut self,
1254 model: Arc<dyn LanguageModel>,
1255 intent: CompletionIntent,
1256 window: Option<AnyWindowHandle>,
1257 cx: &mut Context<Self>,
1258 ) {
1259 if self.remaining_turns == 0 {
1260 return;
1261 }
1262
1263 self.remaining_turns -= 1;
1264
1265 self.flush_notifications(model.clone(), intent, cx);
1266
1267 let request = self.to_completion_request(model.clone(), intent, cx);
1268
1269 self.stream_completion(request, model, intent, window, cx);
1270 }
1271
1272 pub fn used_tools_since_last_user_message(&self) -> bool {
1273 for message in self.messages.iter().rev() {
1274 if self.tool_use.message_has_tool_results(message.id) {
1275 return true;
1276 } else if message.role == Role::User {
1277 return false;
1278 }
1279 }
1280
1281 false
1282 }
1283
1284 pub fn to_completion_request(
1285 &self,
1286 model: Arc<dyn LanguageModel>,
1287 intent: CompletionIntent,
1288 cx: &mut Context<Self>,
1289 ) -> LanguageModelRequest {
1290 let mut request = LanguageModelRequest {
1291 thread_id: Some(self.id.to_string()),
1292 prompt_id: Some(self.last_prompt_id.to_string()),
1293 intent: Some(intent),
1294 mode: None,
1295 messages: vec![],
1296 tools: Vec::new(),
1297 tool_choice: None,
1298 stop: Vec::new(),
1299 temperature: AgentSettings::temperature_for_model(&model, cx),
1300 };
1301
1302 let available_tools = self.available_tools(cx, model.clone());
1303 let available_tool_names = available_tools
1304 .iter()
1305 .map(|tool| tool.name.clone())
1306 .collect();
1307
1308 let model_context = &ModelContext {
1309 available_tools: available_tool_names,
1310 };
1311
1312 if let Some(project_context) = self.project_context.borrow().as_ref() {
1313 match self
1314 .prompt_builder
1315 .generate_assistant_system_prompt(project_context, model_context)
1316 {
1317 Err(err) => {
1318 let message = format!("{err:?}").into();
1319 log::error!("{message}");
1320 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1321 header: "Error generating system prompt".into(),
1322 message,
1323 }));
1324 }
1325 Ok(system_prompt) => {
1326 request.messages.push(LanguageModelRequestMessage {
1327 role: Role::System,
1328 content: vec![MessageContent::Text(system_prompt)],
1329 cache: true,
1330 });
1331 }
1332 }
1333 } else {
1334 let message = "Context for system prompt unexpectedly not ready.".into();
1335 log::error!("{message}");
1336 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1337 header: "Error generating system prompt".into(),
1338 message,
1339 }));
1340 }
1341
1342 let mut message_ix_to_cache = None;
1343 for message in &self.messages {
1344 // ui_only messages are for the UI only, not for the model
1345 if message.ui_only {
1346 continue;
1347 }
1348
1349 let mut request_message = LanguageModelRequestMessage {
1350 role: message.role,
1351 content: Vec::new(),
1352 cache: false,
1353 };
1354
1355 message
1356 .loaded_context
1357 .add_to_request_message(&mut request_message);
1358
1359 for segment in &message.segments {
1360 match segment {
1361 MessageSegment::Text(text) => {
1362 let text = text.trim_end();
1363 if !text.is_empty() {
1364 request_message
1365 .content
1366 .push(MessageContent::Text(text.into()));
1367 }
1368 }
1369 MessageSegment::Thinking { text, signature } => {
1370 if !text.is_empty() {
1371 request_message.content.push(MessageContent::Thinking {
1372 text: text.into(),
1373 signature: signature.clone(),
1374 });
1375 }
1376 }
1377 MessageSegment::RedactedThinking(data) => {
1378 request_message
1379 .content
1380 .push(MessageContent::RedactedThinking(data.clone()));
1381 }
1382 };
1383 }
1384
1385 let mut cache_message = true;
1386 let mut tool_results_message = LanguageModelRequestMessage {
1387 role: Role::User,
1388 content: Vec::new(),
1389 cache: false,
1390 };
1391 for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
1392 if let Some(tool_result) = tool_result {
1393 request_message
1394 .content
1395 .push(MessageContent::ToolUse(tool_use.clone()));
1396 tool_results_message
1397 .content
1398 .push(MessageContent::ToolResult(LanguageModelToolResult {
1399 tool_use_id: tool_use.id.clone(),
1400 tool_name: tool_result.tool_name.clone(),
1401 is_error: tool_result.is_error,
1402 content: if tool_result.content.is_empty() {
1403 // Surprisingly, the API fails if we return an empty string here.
1404 // It thinks we are sending a tool use without a tool result.
1405 "<Tool returned an empty string>".into()
1406 } else {
1407 tool_result.content.clone()
1408 },
1409 output: None,
1410 }));
1411 } else {
1412 cache_message = false;
1413 log::debug!(
1414 "skipped tool use {:?} because it is still pending",
1415 tool_use
1416 );
1417 }
1418 }
1419
1420 if cache_message {
1421 message_ix_to_cache = Some(request.messages.len());
1422 }
1423 request.messages.push(request_message);
1424
1425 if !tool_results_message.content.is_empty() {
1426 if cache_message {
1427 message_ix_to_cache = Some(request.messages.len());
1428 }
1429 request.messages.push(tool_results_message);
1430 }
1431 }
1432
1433 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1434 if let Some(message_ix_to_cache) = message_ix_to_cache {
1435 request.messages[message_ix_to_cache].cache = true;
1436 }
1437
1438 request.tools = available_tools;
1439 request.mode = if model.supports_burn_mode() {
1440 Some(self.completion_mode.into())
1441 } else {
1442 Some(CompletionMode::Normal.into())
1443 };
1444
1445 request
1446 }
1447
1448 fn to_summarize_request(
1449 &self,
1450 model: &Arc<dyn LanguageModel>,
1451 intent: CompletionIntent,
1452 added_user_message: String,
1453 cx: &App,
1454 ) -> LanguageModelRequest {
1455 let mut request = LanguageModelRequest {
1456 thread_id: None,
1457 prompt_id: None,
1458 intent: Some(intent),
1459 mode: None,
1460 messages: vec![],
1461 tools: Vec::new(),
1462 tool_choice: None,
1463 stop: Vec::new(),
1464 temperature: AgentSettings::temperature_for_model(model, cx),
1465 };
1466
1467 for message in &self.messages {
1468 let mut request_message = LanguageModelRequestMessage {
1469 role: message.role,
1470 content: Vec::new(),
1471 cache: false,
1472 };
1473
1474 for segment in &message.segments {
1475 match segment {
1476 MessageSegment::Text(text) => request_message
1477 .content
1478 .push(MessageContent::Text(text.clone())),
1479 MessageSegment::Thinking { .. } => {}
1480 MessageSegment::RedactedThinking(_) => {}
1481 }
1482 }
1483
1484 if request_message.content.is_empty() {
1485 continue;
1486 }
1487
1488 request.messages.push(request_message);
1489 }
1490
1491 request.messages.push(LanguageModelRequestMessage {
1492 role: Role::User,
1493 content: vec![MessageContent::Text(added_user_message)],
1494 cache: false,
1495 });
1496
1497 request
1498 }
1499
1500 /// Insert auto-generated notifications (if any) to the thread
1501 fn flush_notifications(
1502 &mut self,
1503 model: Arc<dyn LanguageModel>,
1504 intent: CompletionIntent,
1505 cx: &mut Context<Self>,
1506 ) {
1507 match intent {
1508 CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
1509 if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
1510 cx.emit(ThreadEvent::ToolFinished {
1511 tool_use_id: pending_tool_use.id.clone(),
1512 pending_tool_use: Some(pending_tool_use),
1513 });
1514 }
1515 }
1516 CompletionIntent::ThreadSummarization
1517 | CompletionIntent::ThreadContextSummarization
1518 | CompletionIntent::CreateFile
1519 | CompletionIntent::EditFile
1520 | CompletionIntent::InlineAssist
1521 | CompletionIntent::TerminalInlineAssist
1522 | CompletionIntent::GenerateGitCommitMessage => {}
1523 };
1524 }
1525
1526 fn attach_tracked_files_state(
1527 &mut self,
1528 model: Arc<dyn LanguageModel>,
1529 cx: &mut App,
1530 ) -> Option<PendingToolUse> {
1531 let action_log = self.action_log.read(cx);
1532
1533 action_log.unnotified_stale_buffers(cx).next()?;
1534
1535 // Represent notification as a simulated `project_notifications` tool call
1536 let tool_name = Arc::from("project_notifications");
1537 let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
1538 debug_panic!("`project_notifications` tool not found");
1539 return None;
1540 };
1541
1542 if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
1543 return None;
1544 }
1545
1546 let input = serde_json::json!({});
1547 let request = Arc::new(LanguageModelRequest::default()); // unused
1548 let window = None;
1549 let tool_result = tool.run(
1550 input,
1551 request,
1552 self.project.clone(),
1553 self.action_log.clone(),
1554 model.clone(),
1555 window,
1556 cx,
1557 );
1558
1559 let tool_use_id =
1560 LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));
1561
1562 let tool_use = LanguageModelToolUse {
1563 id: tool_use_id.clone(),
1564 name: tool_name.clone(),
1565 raw_input: "{}".to_string(),
1566 input: serde_json::json!({}),
1567 is_input_complete: true,
1568 };
1569
1570 let tool_output = cx.background_executor().block(tool_result.output);
1571
1572 // Attach a project_notification tool call to the latest existing
1573 // Assistant message. We cannot create a new Assistant message
1574 // because thinking models require a `thinking` block that we
1575 // cannot mock. We cannot send a notification as a normal
1576 // (non-tool-use) User message because this distracts Agent
1577 // too much.
1578 let tool_message_id = self
1579 .messages
1580 .iter()
1581 .enumerate()
1582 .rfind(|(_, message)| message.role == Role::Assistant)
1583 .map(|(_, message)| message.id)?;
1584
1585 let tool_use_metadata = ToolUseMetadata {
1586 model: model.clone(),
1587 thread_id: self.id.clone(),
1588 prompt_id: self.last_prompt_id.clone(),
1589 };
1590
1591 self.tool_use
1592 .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);
1593
1594 let pending_tool_use = self.tool_use.insert_tool_output(
1595 tool_use_id.clone(),
1596 tool_name,
1597 tool_output,
1598 self.configured_model.as_ref(),
1599 self.completion_mode,
1600 );
1601
1602 pending_tool_use
1603 }
1604
1605 pub fn stream_completion(
1606 &mut self,
1607 request: LanguageModelRequest,
1608 model: Arc<dyn LanguageModel>,
1609 intent: CompletionIntent,
1610 window: Option<AnyWindowHandle>,
1611 cx: &mut Context<Self>,
1612 ) {
1613 self.tool_use_limit_reached = false;
1614
1615 let pending_completion_id = post_inc(&mut self.completion_count);
1616 let mut request_callback_parameters = if self.request_callback.is_some() {
1617 Some((request.clone(), Vec::new()))
1618 } else {
1619 None
1620 };
1621 let prompt_id = self.last_prompt_id.clone();
1622 let tool_use_metadata = ToolUseMetadata {
1623 model: model.clone(),
1624 thread_id: self.id.clone(),
1625 prompt_id: prompt_id.clone(),
1626 };
1627
1628 let completion_mode = request
1629 .mode
1630 .unwrap_or(zed_llm_client::CompletionMode::Normal);
1631
1632 self.last_received_chunk_at = Some(Instant::now());
1633
1634 let task = cx.spawn(async move |thread, cx| {
1635 let stream_completion_future = model.stream_completion(request, &cx);
1636 let initial_token_usage =
1637 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1638 let stream_completion = async {
1639 let mut events = stream_completion_future.await?;
1640
1641 let mut stop_reason = StopReason::EndTurn;
1642 let mut current_token_usage = TokenUsage::default();
1643
1644 thread
1645 .update(cx, |_thread, cx| {
1646 cx.emit(ThreadEvent::NewRequest);
1647 })
1648 .ok();
1649
1650 let mut request_assistant_message_id = None;
1651
1652 while let Some(event) = events.next().await {
1653 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1654 response_events
1655 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1656 }
1657
1658 thread.update(cx, |thread, cx| {
1659 match event? {
1660 LanguageModelCompletionEvent::StartMessage { .. } => {
1661 request_assistant_message_id =
1662 Some(thread.insert_assistant_message(
1663 vec![MessageSegment::Text(String::new())],
1664 cx,
1665 ));
1666 }
1667 LanguageModelCompletionEvent::Stop(reason) => {
1668 stop_reason = reason;
1669 }
1670 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1671 thread.update_token_usage_at_last_message(token_usage);
1672 thread.cumulative_token_usage = thread.cumulative_token_usage
1673 + token_usage
1674 - current_token_usage;
1675 current_token_usage = token_usage;
1676 }
1677 LanguageModelCompletionEvent::Text(chunk) => {
1678 thread.received_chunk();
1679
1680 cx.emit(ThreadEvent::ReceivedTextChunk);
1681 if let Some(last_message) = thread.messages.last_mut() {
1682 if last_message.role == Role::Assistant
1683 && !thread.tool_use.has_tool_results(last_message.id)
1684 {
1685 last_message.push_text(&chunk);
1686 cx.emit(ThreadEvent::StreamedAssistantText(
1687 last_message.id,
1688 chunk,
1689 ));
1690 } else {
1691 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1692 // of a new Assistant response.
1693 //
1694 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1695 // will result in duplicating the text of the chunk in the rendered Markdown.
1696 request_assistant_message_id =
1697 Some(thread.insert_assistant_message(
1698 vec![MessageSegment::Text(chunk.to_string())],
1699 cx,
1700 ));
1701 };
1702 }
1703 }
1704 LanguageModelCompletionEvent::Thinking {
1705 text: chunk,
1706 signature,
1707 } => {
1708 thread.received_chunk();
1709
1710 if let Some(last_message) = thread.messages.last_mut() {
1711 if last_message.role == Role::Assistant
1712 && !thread.tool_use.has_tool_results(last_message.id)
1713 {
1714 last_message.push_thinking(&chunk, signature);
1715 cx.emit(ThreadEvent::StreamedAssistantThinking(
1716 last_message.id,
1717 chunk,
1718 ));
1719 } else {
1720 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1721 // of a new Assistant response.
1722 //
1723 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1724 // will result in duplicating the text of the chunk in the rendered Markdown.
1725 request_assistant_message_id =
1726 Some(thread.insert_assistant_message(
1727 vec![MessageSegment::Thinking {
1728 text: chunk.to_string(),
1729 signature,
1730 }],
1731 cx,
1732 ));
1733 };
1734 }
1735 }
1736 LanguageModelCompletionEvent::RedactedThinking { data } => {
1737 thread.received_chunk();
1738
1739 if let Some(last_message) = thread.messages.last_mut() {
1740 if last_message.role == Role::Assistant
1741 && !thread.tool_use.has_tool_results(last_message.id)
1742 {
1743 last_message.push_redacted_thinking(data);
1744 } else {
1745 request_assistant_message_id =
1746 Some(thread.insert_assistant_message(
1747 vec![MessageSegment::RedactedThinking(data)],
1748 cx,
1749 ));
1750 };
1751 }
1752 }
1753 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1754 let last_assistant_message_id = request_assistant_message_id
1755 .unwrap_or_else(|| {
1756 let new_assistant_message_id =
1757 thread.insert_assistant_message(vec![], cx);
1758 request_assistant_message_id =
1759 Some(new_assistant_message_id);
1760 new_assistant_message_id
1761 });
1762
1763 let tool_use_id = tool_use.id.clone();
1764 let streamed_input = if tool_use.is_input_complete {
1765 None
1766 } else {
1767 Some((&tool_use.input).clone())
1768 };
1769
1770 let ui_text = thread.tool_use.request_tool_use(
1771 last_assistant_message_id,
1772 tool_use,
1773 tool_use_metadata.clone(),
1774 cx,
1775 );
1776
1777 if let Some(input) = streamed_input {
1778 cx.emit(ThreadEvent::StreamedToolUse {
1779 tool_use_id,
1780 ui_text,
1781 input,
1782 });
1783 }
1784 }
1785 LanguageModelCompletionEvent::ToolUseJsonParseError {
1786 id,
1787 tool_name,
1788 raw_input: invalid_input_json,
1789 json_parse_error,
1790 } => {
1791 thread.receive_invalid_tool_json(
1792 id,
1793 tool_name,
1794 invalid_input_json,
1795 json_parse_error,
1796 window,
1797 cx,
1798 );
1799 }
1800 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1801 if let Some(completion) = thread
1802 .pending_completions
1803 .iter_mut()
1804 .find(|completion| completion.id == pending_completion_id)
1805 {
1806 match status_update {
1807 CompletionRequestStatus::Queued { position } => {
1808 completion.queue_state =
1809 QueueState::Queued { position };
1810 }
1811 CompletionRequestStatus::Started => {
1812 completion.queue_state = QueueState::Started;
1813 }
1814 CompletionRequestStatus::Failed {
1815 code,
1816 message,
1817 request_id: _,
1818 retry_after,
1819 } => {
1820 return Err(
1821 LanguageModelCompletionError::from_cloud_failure(
1822 model.upstream_provider_name(),
1823 code,
1824 message,
1825 retry_after.map(Duration::from_secs_f64),
1826 ),
1827 );
1828 }
1829 CompletionRequestStatus::UsageUpdated { amount, limit } => {
1830 thread.update_model_request_usage(
1831 amount as u32,
1832 limit,
1833 cx,
1834 );
1835 }
1836 CompletionRequestStatus::ToolUseLimitReached => {
1837 thread.tool_use_limit_reached = true;
1838 cx.emit(ThreadEvent::ToolUseLimitReached);
1839 }
1840 }
1841 }
1842 }
1843 }
1844
1845 thread.touch_updated_at();
1846 cx.emit(ThreadEvent::StreamedCompletion);
1847 cx.notify();
1848
1849 thread.auto_capture_telemetry(cx);
1850 Ok(())
1851 })??;
1852
1853 smol::future::yield_now().await;
1854 }
1855
1856 thread.update(cx, |thread, cx| {
1857 thread.last_received_chunk_at = None;
1858 thread
1859 .pending_completions
1860 .retain(|completion| completion.id != pending_completion_id);
1861
1862 // If there is a response without tool use, summarize the message. Otherwise,
1863 // allow two tool uses before summarizing.
1864 if matches!(thread.summary, ThreadSummary::Pending)
1865 && thread.messages.len() >= 2
1866 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1867 {
1868 thread.summarize(cx);
1869 }
1870 })?;
1871
1872 anyhow::Ok(stop_reason)
1873 };
1874
1875 let result = stream_completion.await;
1876 let mut retry_scheduled = false;
1877
1878 thread
1879 .update(cx, |thread, cx| {
1880 thread.finalize_pending_checkpoint(cx);
1881 match result.as_ref() {
1882 Ok(stop_reason) => {
1883 match stop_reason {
1884 StopReason::ToolUse => {
1885 let tool_uses =
1886 thread.use_pending_tools(window, model.clone(), cx);
1887 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1888 }
1889 StopReason::EndTurn | StopReason::MaxTokens => {
1890 thread.project.update(cx, |project, cx| {
1891 project.set_agent_location(None, cx);
1892 });
1893 }
1894 StopReason::Refusal => {
1895 thread.project.update(cx, |project, cx| {
1896 project.set_agent_location(None, cx);
1897 });
1898
1899 // Remove the turn that was refused.
1900 //
1901 // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1902 {
1903 let mut messages_to_remove = Vec::new();
1904
1905 for (ix, message) in
1906 thread.messages.iter().enumerate().rev()
1907 {
1908 messages_to_remove.push(message.id);
1909
1910 if message.role == Role::User {
1911 if ix == 0 {
1912 break;
1913 }
1914
1915 if let Some(prev_message) =
1916 thread.messages.get(ix - 1)
1917 {
1918 if prev_message.role == Role::Assistant {
1919 break;
1920 }
1921 }
1922 }
1923 }
1924
1925 for message_id in messages_to_remove {
1926 thread.delete_message(message_id, cx);
1927 }
1928 }
1929
1930 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1931 header: "Language model refusal".into(),
1932 message:
1933 "Model refused to generate content for safety reasons."
1934 .into(),
1935 }));
1936 }
1937 }
1938
1939 // We successfully completed, so cancel any remaining retries.
1940 thread.retry_state = None;
1941 }
1942 Err(error) => {
1943 thread.project.update(cx, |project, cx| {
1944 project.set_agent_location(None, cx);
1945 });
1946
1947 if error.is::<PaymentRequiredError>() {
1948 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1949 } else if let Some(error) =
1950 error.downcast_ref::<ModelRequestLimitReachedError>()
1951 {
1952 cx.emit(ThreadEvent::ShowError(
1953 ThreadError::ModelRequestLimitReached { plan: error.plan },
1954 ));
1955 } else if let Some(completion_error) =
1956 error.downcast_ref::<LanguageModelCompletionError>()
1957 {
1958 match &completion_error {
1959 LanguageModelCompletionError::PromptTooLarge {
1960 tokens, ..
1961 } => {
1962 let tokens = tokens.unwrap_or_else(|| {
1963 // We didn't get an exact token count from the API, so fall back on our estimate.
1964 thread
1965 .total_token_usage()
1966 .map(|usage| usage.total)
1967 .unwrap_or(0)
1968 // We know the context window was exceeded in practice, so if our estimate was
1969 // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
1970 .max(
1971 model
1972 .max_token_count_for_mode(completion_mode)
1973 .saturating_add(1),
1974 )
1975 });
1976 thread.exceeded_window_error = Some(ExceededWindowError {
1977 model_id: model.id(),
1978 token_count: tokens,
1979 });
1980 cx.notify();
1981 }
1982 _ => {
1983 if let Some(retry_strategy) =
1984 Thread::get_retry_strategy(completion_error)
1985 {
1986 retry_scheduled = thread
1987 .handle_retryable_error_with_delay(
1988 &completion_error,
1989 Some(retry_strategy),
1990 model.clone(),
1991 intent,
1992 window,
1993 cx,
1994 );
1995 }
1996 }
1997 }
1998 }
1999
2000 if !retry_scheduled {
2001 thread.cancel_last_completion(window, cx);
2002 }
2003 }
2004 }
2005
2006 if !retry_scheduled {
2007 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
2008 }
2009
2010 if let Some((request_callback, (request, response_events))) = thread
2011 .request_callback
2012 .as_mut()
2013 .zip(request_callback_parameters.as_ref())
2014 {
2015 request_callback(request, response_events);
2016 }
2017
2018 thread.auto_capture_telemetry(cx);
2019
2020 if let Ok(initial_usage) = initial_token_usage {
2021 let usage = thread.cumulative_token_usage - initial_usage;
2022
2023 telemetry::event!(
2024 "Assistant Thread Completion",
2025 thread_id = thread.id().to_string(),
2026 prompt_id = prompt_id,
2027 model = model.telemetry_id(),
2028 model_provider = model.provider_id().to_string(),
2029 input_tokens = usage.input_tokens,
2030 output_tokens = usage.output_tokens,
2031 cache_creation_input_tokens = usage.cache_creation_input_tokens,
2032 cache_read_input_tokens = usage.cache_read_input_tokens,
2033 );
2034 }
2035 })
2036 .ok();
2037 });
2038
2039 self.pending_completions.push(PendingCompletion {
2040 id: pending_completion_id,
2041 queue_state: QueueState::Sending,
2042 _task: task,
2043 });
2044 }
2045
2046 pub fn summarize(&mut self, cx: &mut Context<Self>) {
2047 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2048 println!("No thread summary model");
2049 return;
2050 };
2051
2052 if !model.provider.is_authenticated(cx) {
2053 return;
2054 }
2055
2056 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2057
2058 let request = self.to_summarize_request(
2059 &model.model,
2060 CompletionIntent::ThreadSummarization,
2061 added_user_message.into(),
2062 cx,
2063 );
2064
2065 self.summary = ThreadSummary::Generating;
2066
2067 self.pending_summary = cx.spawn(async move |this, cx| {
2068 let result = async {
2069 let mut messages = model.model.stream_completion(request, &cx).await?;
2070
2071 let mut new_summary = String::new();
2072 while let Some(event) = messages.next().await {
2073 let Ok(event) = event else {
2074 continue;
2075 };
2076 let text = match event {
2077 LanguageModelCompletionEvent::Text(text) => text,
2078 LanguageModelCompletionEvent::StatusUpdate(
2079 CompletionRequestStatus::UsageUpdated { amount, limit },
2080 ) => {
2081 this.update(cx, |thread, cx| {
2082 thread.update_model_request_usage(amount as u32, limit, cx);
2083 })?;
2084 continue;
2085 }
2086 _ => continue,
2087 };
2088
2089 let mut lines = text.lines();
2090 new_summary.extend(lines.next());
2091
2092 // Stop if the LLM generated multiple lines.
2093 if lines.next().is_some() {
2094 break;
2095 }
2096 }
2097
2098 anyhow::Ok(new_summary)
2099 }
2100 .await;
2101
2102 this.update(cx, |this, cx| {
2103 match result {
2104 Ok(new_summary) => {
2105 if new_summary.is_empty() {
2106 this.summary = ThreadSummary::Error;
2107 } else {
2108 this.summary = ThreadSummary::Ready(new_summary.into());
2109 }
2110 }
2111 Err(err) => {
2112 this.summary = ThreadSummary::Error;
2113 log::error!("Failed to generate thread summary: {}", err);
2114 }
2115 }
2116 cx.emit(ThreadEvent::SummaryGenerated);
2117 })
2118 .log_err()?;
2119
2120 Some(())
2121 });
2122 }
2123
2124 fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
2125 use LanguageModelCompletionError::*;
2126
2127 // General strategy here:
2128 // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
2129 // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), try multiple times with exponential backoff.
2130 // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), just retry once.
2131 match error {
2132 HttpResponseError {
2133 status_code: StatusCode::TOO_MANY_REQUESTS,
2134 ..
2135 } => Some(RetryStrategy::ExponentialBackoff {
2136 initial_delay: BASE_RETRY_DELAY,
2137 max_attempts: MAX_RETRY_ATTEMPTS,
2138 }),
2139 ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
2140 Some(RetryStrategy::Fixed {
2141 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2142 max_attempts: MAX_RETRY_ATTEMPTS,
2143 })
2144 }
2145 UpstreamProviderError {
2146 status,
2147 retry_after,
2148 ..
2149 } => match *status {
2150 StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
2151 Some(RetryStrategy::Fixed {
2152 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2153 max_attempts: MAX_RETRY_ATTEMPTS,
2154 })
2155 }
2156 StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
2157 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2158 // Internal Server Error could be anything, so only retry once.
2159 max_attempts: 1,
2160 }),
2161 status => {
2162 // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
2163 // but we frequently get them in practice. See https://http.dev/529
2164 if status.as_u16() == 529 {
2165 Some(RetryStrategy::Fixed {
2166 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2167 max_attempts: MAX_RETRY_ATTEMPTS,
2168 })
2169 } else {
2170 None
2171 }
2172 }
2173 },
2174 ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
2175 delay: BASE_RETRY_DELAY,
2176 max_attempts: 1,
2177 }),
2178 ApiReadResponseError { .. }
2179 | HttpSend { .. }
2180 | DeserializeResponse { .. }
2181 | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
2182 delay: BASE_RETRY_DELAY,
2183 max_attempts: 1,
2184 }),
2185 // Retrying these errors definitely shouldn't help.
2186 HttpResponseError {
2187 status_code:
2188 StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
2189 ..
2190 }
2191 | SerializeRequest { .. }
2192 | BuildRequestBody { .. }
2193 | PromptTooLarge { .. }
2194 | AuthenticationError { .. }
2195 | PermissionError { .. }
2196 | ApiEndpointNotFound { .. }
2197 | NoApiKey { .. } => None,
2198 // Retry all other 4xx and 5xx errors once.
2199 HttpResponseError { status_code, .. }
2200 if status_code.is_client_error() || status_code.is_server_error() =>
2201 {
2202 Some(RetryStrategy::Fixed {
2203 delay: BASE_RETRY_DELAY,
2204 max_attempts: 1,
2205 })
2206 }
2207 // Conservatively assume that any other errors are non-retryable
2208 HttpResponseError { .. } | Other(..) => None,
2209 }
2210 }
2211
2212 fn handle_retryable_error_with_delay(
2213 &mut self,
2214 error: &LanguageModelCompletionError,
2215 strategy: Option<RetryStrategy>,
2216 model: Arc<dyn LanguageModel>,
2217 intent: CompletionIntent,
2218 window: Option<AnyWindowHandle>,
2219 cx: &mut Context<Self>,
2220 ) -> bool {
2221 let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
2222 return false;
2223 };
2224
2225 let max_attempts = match &strategy {
2226 RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
2227 RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
2228 };
2229
2230 let retry_state = self.retry_state.get_or_insert(RetryState {
2231 attempt: 0,
2232 max_attempts,
2233 intent,
2234 });
2235
2236 retry_state.attempt += 1;
2237 let attempt = retry_state.attempt;
2238 let max_attempts = retry_state.max_attempts;
2239 let intent = retry_state.intent;
2240
2241 if attempt <= max_attempts {
2242 let delay = match &strategy {
2243 RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
2244 let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
2245 Duration::from_secs(delay_secs)
2246 }
2247 RetryStrategy::Fixed { delay, .. } => *delay,
2248 };
2249
2250 // Add a transient message to inform the user
2251 let delay_secs = delay.as_secs();
2252 let retry_message = if max_attempts == 1 {
2253 format!("{error}. Retrying in {delay_secs} seconds...")
2254 } else {
2255 format!(
2256 "{error}. Retrying (attempt {attempt} of {max_attempts}) \
2257 in {delay_secs} seconds..."
2258 )
2259 };
2260 log::warn!(
2261 "Retrying completion request (attempt {attempt} of {max_attempts}) \
2262 in {delay_secs} seconds: {error:?}",
2263 );
2264
2265 // Add a UI-only message instead of a regular message
2266 let id = self.next_message_id.post_inc();
2267 self.messages.push(Message {
2268 id,
2269 role: Role::System,
2270 segments: vec![MessageSegment::Text(retry_message)],
2271 loaded_context: LoadedContext::default(),
2272 creases: Vec::new(),
2273 is_hidden: false,
2274 ui_only: true,
2275 });
2276 cx.emit(ThreadEvent::MessageAdded(id));
2277
2278 // Schedule the retry
2279 let thread_handle = cx.entity().downgrade();
2280
2281 cx.spawn(async move |_thread, cx| {
2282 cx.background_executor().timer(delay).await;
2283
2284 thread_handle
2285 .update(cx, |thread, cx| {
2286 // Retry the completion
2287 thread.send_to_model(model, intent, window, cx);
2288 })
2289 .log_err();
2290 })
2291 .detach();
2292
2293 true
2294 } else {
2295 // Max retries exceeded
2296 self.retry_state = None;
2297
2298 // Stop generating since we're giving up on retrying.
2299 self.pending_completions.clear();
2300
2301 false
2302 }
2303 }
2304
2305 pub fn start_generating_detailed_summary_if_needed(
2306 &mut self,
2307 thread_store: WeakEntity<ThreadStore>,
2308 cx: &mut Context<Self>,
2309 ) {
2310 let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
2311 return;
2312 };
2313
2314 match &*self.detailed_summary_rx.borrow() {
2315 DetailedSummaryState::Generating { message_id, .. }
2316 | DetailedSummaryState::Generated { message_id, .. }
2317 if *message_id == last_message_id =>
2318 {
2319 // Already up-to-date
2320 return;
2321 }
2322 _ => {}
2323 }
2324
2325 let Some(ConfiguredModel { model, provider }) =
2326 LanguageModelRegistry::read_global(cx).thread_summary_model()
2327 else {
2328 return;
2329 };
2330
2331 if !provider.is_authenticated(cx) {
2332 return;
2333 }
2334
2335 let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");
2336
2337 let request = self.to_summarize_request(
2338 &model,
2339 CompletionIntent::ThreadContextSummarization,
2340 added_user_message.into(),
2341 cx,
2342 );
2343
2344 *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
2345 message_id: last_message_id,
2346 };
2347
2348 // Replace the detailed summarization task if there is one, cancelling it. It would probably
2349 // be better to allow the old task to complete, but this would require logic for choosing
2350 // which result to prefer (the old task could complete after the new one, resulting in a
2351 // stale summary).
2352 self.detailed_summary_task = cx.spawn(async move |thread, cx| {
2353 let stream = model.stream_completion_text(request, &cx);
2354 let Some(mut messages) = stream.await.log_err() else {
2355 thread
2356 .update(cx, |thread, _cx| {
2357 *thread.detailed_summary_tx.borrow_mut() =
2358 DetailedSummaryState::NotGenerated;
2359 })
2360 .ok()?;
2361 return None;
2362 };
2363
2364 let mut new_detailed_summary = String::new();
2365
2366 while let Some(chunk) = messages.stream.next().await {
2367 if let Some(chunk) = chunk.log_err() {
2368 new_detailed_summary.push_str(&chunk);
2369 }
2370 }
2371
2372 thread
2373 .update(cx, |thread, _cx| {
2374 *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
2375 text: new_detailed_summary.into(),
2376 message_id: last_message_id,
2377 };
2378 })
2379 .ok()?;
2380
2381 // Save thread so its summary can be reused later
2382 if let Some(thread) = thread.upgrade() {
2383 if let Ok(Ok(save_task)) = cx.update(|cx| {
2384 thread_store
2385 .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
2386 }) {
2387 save_task.await.log_err();
2388 }
2389 }
2390
2391 Some(())
2392 });
2393 }
2394
2395 pub async fn wait_for_detailed_summary_or_text(
2396 this: &Entity<Self>,
2397 cx: &mut AsyncApp,
2398 ) -> Option<SharedString> {
2399 let mut detailed_summary_rx = this
2400 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2401 .ok()?;
2402 loop {
2403 match detailed_summary_rx.recv().await? {
2404 DetailedSummaryState::Generating { .. } => {}
2405 DetailedSummaryState::NotGenerated => {
2406 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2407 }
2408 DetailedSummaryState::Generated { text, .. } => return Some(text),
2409 }
2410 }
2411 }
2412
2413 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2414 self.detailed_summary_rx
2415 .borrow()
2416 .text()
2417 .unwrap_or_else(|| self.text().into())
2418 }
2419
2420 pub fn is_generating_detailed_summary(&self) -> bool {
2421 matches!(
2422 &*self.detailed_summary_rx.borrow(),
2423 DetailedSummaryState::Generating { .. }
2424 )
2425 }
2426
2427 pub fn use_pending_tools(
2428 &mut self,
2429 window: Option<AnyWindowHandle>,
2430 model: Arc<dyn LanguageModel>,
2431 cx: &mut Context<Self>,
2432 ) -> Vec<PendingToolUse> {
2433 self.auto_capture_telemetry(cx);
2434 let request =
2435 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2436 let pending_tool_uses = self
2437 .tool_use
2438 .pending_tool_uses()
2439 .into_iter()
2440 .filter(|tool_use| tool_use.status.is_idle())
2441 .cloned()
2442 .collect::<Vec<_>>();
2443
2444 for tool_use in pending_tool_uses.iter() {
2445 self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2446 }
2447
2448 pending_tool_uses
2449 }
2450
2451 fn use_pending_tool(
2452 &mut self,
2453 tool_use: PendingToolUse,
2454 request: Arc<LanguageModelRequest>,
2455 model: Arc<dyn LanguageModel>,
2456 window: Option<AnyWindowHandle>,
2457 cx: &mut Context<Self>,
2458 ) {
2459 let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2460 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2461 };
2462
2463 if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2464 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2465 }
2466
2467 if tool.needs_confirmation(&tool_use.input, cx)
2468 && !AgentSettings::get_global(cx).always_allow_tool_actions
2469 {
2470 self.tool_use.confirm_tool_use(
2471 tool_use.id,
2472 tool_use.ui_text,
2473 tool_use.input,
2474 request,
2475 tool,
2476 );
2477 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2478 } else {
2479 self.run_tool(
2480 tool_use.id,
2481 tool_use.ui_text,
2482 tool_use.input,
2483 request,
2484 tool,
2485 model,
2486 window,
2487 cx,
2488 );
2489 }
2490 }
2491
2492 pub fn handle_hallucinated_tool_use(
2493 &mut self,
2494 tool_use_id: LanguageModelToolUseId,
2495 hallucinated_tool_name: Arc<str>,
2496 window: Option<AnyWindowHandle>,
2497 cx: &mut Context<Thread>,
2498 ) {
2499 let available_tools = self.profile.enabled_tools(cx);
2500
2501 let tool_list = available_tools
2502 .iter()
2503 .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2504 .collect::<Vec<_>>()
2505 .join("\n");
2506
2507 let error_message = format!(
2508 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2509 hallucinated_tool_name, tool_list
2510 );
2511
2512 let pending_tool_use = self.tool_use.insert_tool_output(
2513 tool_use_id.clone(),
2514 hallucinated_tool_name,
2515 Err(anyhow!("Missing tool call: {error_message}")),
2516 self.configured_model.as_ref(),
2517 self.completion_mode,
2518 );
2519
2520 cx.emit(ThreadEvent::MissingToolUse {
2521 tool_use_id: tool_use_id.clone(),
2522 ui_text: error_message.into(),
2523 });
2524
2525 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2526 }
2527
2528 pub fn receive_invalid_tool_json(
2529 &mut self,
2530 tool_use_id: LanguageModelToolUseId,
2531 tool_name: Arc<str>,
2532 invalid_json: Arc<str>,
2533 error: String,
2534 window: Option<AnyWindowHandle>,
2535 cx: &mut Context<Thread>,
2536 ) {
2537 log::error!("The model returned invalid input JSON: {invalid_json}");
2538
2539 let pending_tool_use = self.tool_use.insert_tool_output(
2540 tool_use_id.clone(),
2541 tool_name,
2542 Err(anyhow!("Error parsing input JSON: {error}")),
2543 self.configured_model.as_ref(),
2544 self.completion_mode,
2545 );
2546 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2547 pending_tool_use.ui_text.clone()
2548 } else {
2549 log::error!(
2550 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2551 );
2552 format!("Unknown tool {}", tool_use_id).into()
2553 };
2554
2555 cx.emit(ThreadEvent::InvalidToolInput {
2556 tool_use_id: tool_use_id.clone(),
2557 ui_text,
2558 invalid_input_json: invalid_json,
2559 });
2560
2561 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2562 }
2563
2564 pub fn run_tool(
2565 &mut self,
2566 tool_use_id: LanguageModelToolUseId,
2567 ui_text: impl Into<SharedString>,
2568 input: serde_json::Value,
2569 request: Arc<LanguageModelRequest>,
2570 tool: Arc<dyn Tool>,
2571 model: Arc<dyn LanguageModel>,
2572 window: Option<AnyWindowHandle>,
2573 cx: &mut Context<Thread>,
2574 ) {
2575 let task =
2576 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2577 self.tool_use
2578 .run_pending_tool(tool_use_id, ui_text.into(), task);
2579 }
2580
2581 fn spawn_tool_use(
2582 &mut self,
2583 tool_use_id: LanguageModelToolUseId,
2584 request: Arc<LanguageModelRequest>,
2585 input: serde_json::Value,
2586 tool: Arc<dyn Tool>,
2587 model: Arc<dyn LanguageModel>,
2588 window: Option<AnyWindowHandle>,
2589 cx: &mut Context<Thread>,
2590 ) -> Task<()> {
2591 let tool_name: Arc<str> = tool.name().into();
2592
2593 let tool_result = tool.run(
2594 input,
2595 request,
2596 self.project.clone(),
2597 self.action_log.clone(),
2598 model,
2599 window,
2600 cx,
2601 );
2602
2603 // Store the card separately if it exists
2604 if let Some(card) = tool_result.card.clone() {
2605 self.tool_use
2606 .insert_tool_result_card(tool_use_id.clone(), card);
2607 }
2608
2609 cx.spawn({
2610 async move |thread: WeakEntity<Thread>, cx| {
2611 let output = tool_result.output.await;
2612
2613 thread
2614 .update(cx, |thread, cx| {
2615 let pending_tool_use = thread.tool_use.insert_tool_output(
2616 tool_use_id.clone(),
2617 tool_name,
2618 output,
2619 thread.configured_model.as_ref(),
2620 thread.completion_mode,
2621 );
2622 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2623 })
2624 .ok();
2625 }
2626 })
2627 }
2628
2629 fn tool_finished(
2630 &mut self,
2631 tool_use_id: LanguageModelToolUseId,
2632 pending_tool_use: Option<PendingToolUse>,
2633 canceled: bool,
2634 window: Option<AnyWindowHandle>,
2635 cx: &mut Context<Self>,
2636 ) {
2637 if self.all_tools_finished() {
2638 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2639 if !canceled {
2640 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2641 }
2642 self.auto_capture_telemetry(cx);
2643 }
2644 }
2645
2646 cx.emit(ThreadEvent::ToolFinished {
2647 tool_use_id,
2648 pending_tool_use,
2649 });
2650 }
2651
2652 /// Cancels the last pending completion, if there are any pending.
2653 ///
2654 /// Returns whether a completion was canceled.
2655 pub fn cancel_last_completion(
2656 &mut self,
2657 window: Option<AnyWindowHandle>,
2658 cx: &mut Context<Self>,
2659 ) -> bool {
2660 let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2661
2662 self.retry_state = None;
2663
2664 for pending_tool_use in self.tool_use.cancel_pending() {
2665 canceled = true;
2666 self.tool_finished(
2667 pending_tool_use.id.clone(),
2668 Some(pending_tool_use),
2669 true,
2670 window,
2671 cx,
2672 );
2673 }
2674
2675 if canceled {
2676 cx.emit(ThreadEvent::CompletionCanceled);
2677
2678 // When canceled, we always want to insert the checkpoint.
2679 // (We skip over finalize_pending_checkpoint, because it
2680 // would conclude we didn't have anything to insert here.)
2681 if let Some(checkpoint) = self.pending_checkpoint.take() {
2682 self.insert_checkpoint(checkpoint, cx);
2683 }
2684 } else {
2685 self.finalize_pending_checkpoint(cx);
2686 }
2687
2688 canceled
2689 }
2690
2691 /// Signals that any in-progress editing should be canceled.
2692 ///
2693 /// This method is used to notify listeners (like ActiveThread) that
2694 /// they should cancel any editing operations.
2695 pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
2696 cx.emit(ThreadEvent::CancelEditing);
2697 }
2698
2699 pub fn feedback(&self) -> Option<ThreadFeedback> {
2700 self.feedback
2701 }
2702
2703 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
2704 self.message_feedback.get(&message_id).copied()
2705 }
2706
2707 pub fn report_message_feedback(
2708 &mut self,
2709 message_id: MessageId,
2710 feedback: ThreadFeedback,
2711 cx: &mut Context<Self>,
2712 ) -> Task<Result<()>> {
2713 if self.message_feedback.get(&message_id) == Some(&feedback) {
2714 return Task::ready(Ok(()));
2715 }
2716
2717 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2718 let serialized_thread = self.serialize(cx);
2719 let thread_id = self.id().clone();
2720 let client = self.project.read(cx).client();
2721
2722 let enabled_tool_names: Vec<String> = self
2723 .profile
2724 .enabled_tools(cx)
2725 .iter()
2726 .map(|(name, _)| name.clone().into())
2727 .collect();
2728
2729 self.message_feedback.insert(message_id, feedback);
2730
2731 cx.notify();
2732
2733 let message_content = self
2734 .message(message_id)
2735 .map(|msg| msg.to_string())
2736 .unwrap_or_default();
2737
2738 cx.background_spawn(async move {
2739 let final_project_snapshot = final_project_snapshot.await;
2740 let serialized_thread = serialized_thread.await?;
2741 let thread_data =
2742 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2743
2744 let rating = match feedback {
2745 ThreadFeedback::Positive => "positive",
2746 ThreadFeedback::Negative => "negative",
2747 };
2748 telemetry::event!(
2749 "Assistant Thread Rated",
2750 rating,
2751 thread_id,
2752 enabled_tool_names,
2753 message_id = message_id.0,
2754 message_content,
2755 thread_data,
2756 final_project_snapshot
2757 );
2758 client.telemetry().flush_events().await;
2759
2760 Ok(())
2761 })
2762 }
2763
2764 pub fn report_feedback(
2765 &mut self,
2766 feedback: ThreadFeedback,
2767 cx: &mut Context<Self>,
2768 ) -> Task<Result<()>> {
2769 let last_assistant_message_id = self
2770 .messages
2771 .iter()
2772 .rev()
2773 .find(|msg| msg.role == Role::Assistant)
2774 .map(|msg| msg.id);
2775
2776 if let Some(message_id) = last_assistant_message_id {
2777 self.report_message_feedback(message_id, feedback, cx)
2778 } else {
2779 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2780 let serialized_thread = self.serialize(cx);
2781 let thread_id = self.id().clone();
2782 let client = self.project.read(cx).client();
2783 self.feedback = Some(feedback);
2784 cx.notify();
2785
2786 cx.background_spawn(async move {
2787 let final_project_snapshot = final_project_snapshot.await;
2788 let serialized_thread = serialized_thread.await?;
2789 let thread_data = serde_json::to_value(serialized_thread)
2790 .unwrap_or_else(|_| serde_json::Value::Null);
2791
2792 let rating = match feedback {
2793 ThreadFeedback::Positive => "positive",
2794 ThreadFeedback::Negative => "negative",
2795 };
2796 telemetry::event!(
2797 "Assistant Thread Rated",
2798 rating,
2799 thread_id,
2800 thread_data,
2801 final_project_snapshot
2802 );
2803 client.telemetry().flush_events().await;
2804
2805 Ok(())
2806 })
2807 }
2808 }
2809
2810 /// Create a snapshot of the current project state including git information and unsaved buffers.
2811 fn project_snapshot(
2812 project: Entity<Project>,
2813 cx: &mut Context<Self>,
2814 ) -> Task<Arc<ProjectSnapshot>> {
2815 let git_store = project.read(cx).git_store().clone();
2816 let worktree_snapshots: Vec<_> = project
2817 .read(cx)
2818 .visible_worktrees(cx)
2819 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2820 .collect();
2821
2822 cx.spawn(async move |_, cx| {
2823 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2824
2825 let mut unsaved_buffers = Vec::new();
2826 cx.update(|app_cx| {
2827 let buffer_store = project.read(app_cx).buffer_store();
2828 for buffer_handle in buffer_store.read(app_cx).buffers() {
2829 let buffer = buffer_handle.read(app_cx);
2830 if buffer.is_dirty() {
2831 if let Some(file) = buffer.file() {
2832 let path = file.path().to_string_lossy().to_string();
2833 unsaved_buffers.push(path);
2834 }
2835 }
2836 }
2837 })
2838 .ok();
2839
2840 Arc::new(ProjectSnapshot {
2841 worktree_snapshots,
2842 unsaved_buffer_paths: unsaved_buffers,
2843 timestamp: Utc::now(),
2844 })
2845 })
2846 }
2847
2848 fn worktree_snapshot(
2849 worktree: Entity<project::Worktree>,
2850 git_store: Entity<GitStore>,
2851 cx: &App,
2852 ) -> Task<WorktreeSnapshot> {
2853 cx.spawn(async move |cx| {
2854 // Get worktree path and snapshot
2855 let worktree_info = cx.update(|app_cx| {
2856 let worktree = worktree.read(app_cx);
2857 let path = worktree.abs_path().to_string_lossy().to_string();
2858 let snapshot = worktree.snapshot();
2859 (path, snapshot)
2860 });
2861
2862 let Ok((worktree_path, _snapshot)) = worktree_info else {
2863 return WorktreeSnapshot {
2864 worktree_path: String::new(),
2865 git_state: None,
2866 };
2867 };
2868
2869 let git_state = git_store
2870 .update(cx, |git_store, cx| {
2871 git_store
2872 .repositories()
2873 .values()
2874 .find(|repo| {
2875 repo.read(cx)
2876 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2877 .is_some()
2878 })
2879 .cloned()
2880 })
2881 .ok()
2882 .flatten()
2883 .map(|repo| {
2884 repo.update(cx, |repo, _| {
2885 let current_branch =
2886 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2887 repo.send_job(None, |state, _| async move {
2888 let RepositoryState::Local { backend, .. } = state else {
2889 return GitState {
2890 remote_url: None,
2891 head_sha: None,
2892 current_branch,
2893 diff: None,
2894 };
2895 };
2896
2897 let remote_url = backend.remote_url("origin");
2898 let head_sha = backend.head_sha().await;
2899 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2900
2901 GitState {
2902 remote_url,
2903 head_sha,
2904 current_branch,
2905 diff,
2906 }
2907 })
2908 })
2909 });
2910
2911 let git_state = match git_state {
2912 Some(git_state) => match git_state.ok() {
2913 Some(git_state) => git_state.await.ok(),
2914 None => None,
2915 },
2916 None => None,
2917 };
2918
2919 WorktreeSnapshot {
2920 worktree_path,
2921 git_state,
2922 }
2923 })
2924 }
2925
2926 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2927 let mut markdown = Vec::new();
2928
2929 let summary = self.summary().or_default();
2930 writeln!(markdown, "# {summary}\n")?;
2931
2932 for message in self.messages() {
2933 writeln!(
2934 markdown,
2935 "## {role}\n",
2936 role = match message.role {
2937 Role::User => "User",
2938 Role::Assistant => "Agent",
2939 Role::System => "System",
2940 }
2941 )?;
2942
2943 if !message.loaded_context.text.is_empty() {
2944 writeln!(markdown, "{}", message.loaded_context.text)?;
2945 }
2946
2947 if !message.loaded_context.images.is_empty() {
2948 writeln!(
2949 markdown,
2950 "\n{} images attached as context.\n",
2951 message.loaded_context.images.len()
2952 )?;
2953 }
2954
2955 for segment in &message.segments {
2956 match segment {
2957 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2958 MessageSegment::Thinking { text, .. } => {
2959 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2960 }
2961 MessageSegment::RedactedThinking(_) => {}
2962 }
2963 }
2964
2965 for tool_use in self.tool_uses_for_message(message.id, cx) {
2966 writeln!(
2967 markdown,
2968 "**Use Tool: {} ({})**",
2969 tool_use.name, tool_use.id
2970 )?;
2971 writeln!(markdown, "```json")?;
2972 writeln!(
2973 markdown,
2974 "{}",
2975 serde_json::to_string_pretty(&tool_use.input)?
2976 )?;
2977 writeln!(markdown, "```")?;
2978 }
2979
2980 for tool_result in self.tool_results_for_message(message.id) {
2981 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2982 if tool_result.is_error {
2983 write!(markdown, " (Error)")?;
2984 }
2985
2986 writeln!(markdown, "**\n")?;
2987 match &tool_result.content {
2988 LanguageModelToolResultContent::Text(text) => {
2989 writeln!(markdown, "{text}")?;
2990 }
2991 LanguageModelToolResultContent::Image(image) => {
2992 writeln!(markdown, "", image.source)?;
2993 }
2994 }
2995
2996 if let Some(output) = tool_result.output.as_ref() {
2997 writeln!(
2998 markdown,
2999 "\n\nDebug Output:\n\n```json\n{}\n```\n",
3000 serde_json::to_string_pretty(output)?
3001 )?;
3002 }
3003 }
3004 }
3005
3006 Ok(String::from_utf8_lossy(&markdown).to_string())
3007 }
3008
3009 pub fn keep_edits_in_range(
3010 &mut self,
3011 buffer: Entity<language::Buffer>,
3012 buffer_range: Range<language::Anchor>,
3013 cx: &mut Context<Self>,
3014 ) {
3015 self.action_log.update(cx, |action_log, cx| {
3016 action_log.keep_edits_in_range(buffer, buffer_range, cx)
3017 });
3018 }
3019
3020 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
3021 self.action_log
3022 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
3023 }
3024
3025 pub fn reject_edits_in_ranges(
3026 &mut self,
3027 buffer: Entity<language::Buffer>,
3028 buffer_ranges: Vec<Range<language::Anchor>>,
3029 cx: &mut Context<Self>,
3030 ) -> Task<Result<()>> {
3031 self.action_log.update(cx, |action_log, cx| {
3032 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
3033 })
3034 }
3035
3036 pub fn action_log(&self) -> &Entity<ActionLog> {
3037 &self.action_log
3038 }
3039
3040 pub fn project(&self) -> &Entity<Project> {
3041 &self.project
3042 }
3043
3044 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
3045 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
3046 return;
3047 }
3048
3049 let now = Instant::now();
3050 if let Some(last) = self.last_auto_capture_at {
3051 if now.duration_since(last).as_secs() < 10 {
3052 return;
3053 }
3054 }
3055
3056 self.last_auto_capture_at = Some(now);
3057
3058 let thread_id = self.id().clone();
3059 let github_login = self
3060 .project
3061 .read(cx)
3062 .user_store()
3063 .read(cx)
3064 .current_user()
3065 .map(|user| user.github_login.clone());
3066 let client = self.project.read(cx).client();
3067 let serialize_task = self.serialize(cx);
3068
3069 cx.background_executor()
3070 .spawn(async move {
3071 if let Ok(serialized_thread) = serialize_task.await {
3072 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
3073 telemetry::event!(
3074 "Agent Thread Auto-Captured",
3075 thread_id = thread_id.to_string(),
3076 thread_data = thread_data,
3077 auto_capture_reason = "tracked_user",
3078 github_login = github_login
3079 );
3080
3081 client.telemetry().flush_events().await;
3082 }
3083 }
3084 })
3085 .detach();
3086 }
3087
3088 pub fn cumulative_token_usage(&self) -> TokenUsage {
3089 self.cumulative_token_usage
3090 }
3091
3092 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3093 let Some(model) = self.configured_model.as_ref() else {
3094 return TotalTokenUsage::default();
3095 };
3096
3097 let max = model
3098 .model
3099 .max_token_count_for_mode(self.completion_mode().into());
3100
3101 let index = self
3102 .messages
3103 .iter()
3104 .position(|msg| msg.id == message_id)
3105 .unwrap_or(0);
3106
3107 if index == 0 {
3108 return TotalTokenUsage { total: 0, max };
3109 }
3110
3111 let token_usage = &self
3112 .request_token_usage
3113 .get(index - 1)
3114 .cloned()
3115 .unwrap_or_default();
3116
3117 TotalTokenUsage {
3118 total: token_usage.total_tokens(),
3119 max,
3120 }
3121 }
3122
3123 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3124 let model = self.configured_model.as_ref()?;
3125
3126 let max = model
3127 .model
3128 .max_token_count_for_mode(self.completion_mode().into());
3129
3130 if let Some(exceeded_error) = &self.exceeded_window_error {
3131 if model.model.id() == exceeded_error.model_id {
3132 return Some(TotalTokenUsage {
3133 total: exceeded_error.token_count,
3134 max,
3135 });
3136 }
3137 }
3138
3139 let total = self
3140 .token_usage_at_last_message()
3141 .unwrap_or_default()
3142 .total_tokens();
3143
3144 Some(TotalTokenUsage { total, max })
3145 }
3146
3147 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
3148 self.request_token_usage
3149 .get(self.messages.len().saturating_sub(1))
3150 .or_else(|| self.request_token_usage.last())
3151 .cloned()
3152 }
3153
3154 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3155 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3156 self.request_token_usage
3157 .resize(self.messages.len(), placeholder);
3158
3159 if let Some(last) = self.request_token_usage.last_mut() {
3160 *last = token_usage;
3161 }
3162 }
3163
3164 fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3165 self.project.update(cx, |project, cx| {
3166 project.user_store().update(cx, |user_store, cx| {
3167 user_store.update_model_request_usage(
3168 ModelRequestUsage(RequestUsage {
3169 amount: amount as i32,
3170 limit,
3171 }),
3172 cx,
3173 )
3174 })
3175 });
3176 }
3177
3178 pub fn deny_tool_use(
3179 &mut self,
3180 tool_use_id: LanguageModelToolUseId,
3181 tool_name: Arc<str>,
3182 window: Option<AnyWindowHandle>,
3183 cx: &mut Context<Self>,
3184 ) {
3185 let err = Err(anyhow::anyhow!(
3186 "Permission to run tool action denied by user"
3187 ));
3188
3189 self.tool_use.insert_tool_output(
3190 tool_use_id.clone(),
3191 tool_name,
3192 err,
3193 self.configured_model.as_ref(),
3194 self.completion_mode,
3195 );
3196 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3197 }
3198}
3199
3200#[derive(Debug, Clone, Error)]
3201pub enum ThreadError {
3202 #[error("Payment required")]
3203 PaymentRequired,
3204 #[error("Model request limit reached")]
3205 ModelRequestLimitReached { plan: Plan },
3206 #[error("Message {header}: {message}")]
3207 Message {
3208 header: SharedString,
3209 message: SharedString,
3210 },
3211}
3212
3213#[derive(Debug, Clone)]
3214pub enum ThreadEvent {
3215 ShowError(ThreadError),
3216 StreamedCompletion,
3217 ReceivedTextChunk,
3218 NewRequest,
3219 StreamedAssistantText(MessageId, String),
3220 StreamedAssistantThinking(MessageId, String),
3221 StreamedToolUse {
3222 tool_use_id: LanguageModelToolUseId,
3223 ui_text: Arc<str>,
3224 input: serde_json::Value,
3225 },
3226 MissingToolUse {
3227 tool_use_id: LanguageModelToolUseId,
3228 ui_text: Arc<str>,
3229 },
3230 InvalidToolInput {
3231 tool_use_id: LanguageModelToolUseId,
3232 ui_text: Arc<str>,
3233 invalid_input_json: Arc<str>,
3234 },
3235 Stopped(Result<StopReason, Arc<anyhow::Error>>),
3236 MessageAdded(MessageId),
3237 MessageEdited(MessageId),
3238 MessageDeleted(MessageId),
3239 SummaryGenerated,
3240 SummaryChanged,
3241 UsePendingTools {
3242 tool_uses: Vec<PendingToolUse>,
3243 },
3244 ToolFinished {
3245 #[allow(unused)]
3246 tool_use_id: LanguageModelToolUseId,
3247 /// The pending tool use that corresponds to this tool.
3248 pending_tool_use: Option<PendingToolUse>,
3249 },
3250 CheckpointChanged,
3251 ToolConfirmationNeeded,
3252 ToolUseLimitReached,
3253 CancelEditing,
3254 CompletionCanceled,
3255 ProfileChanged,
3256}
3257
3258impl EventEmitter<ThreadEvent> for Thread {}
3259
3260struct PendingCompletion {
3261 id: usize,
3262 queue_state: QueueState,
3263 _task: Task<()>,
3264}
3265
3266#[cfg(test)]
3267mod tests {
3268 use super::*;
3269 use crate::{
3270 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3271 };
3272
3273 // Test-specific constants
3274 const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3275 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3276 use assistant_tool::ToolRegistry;
3277 use assistant_tools;
3278 use futures::StreamExt;
3279 use futures::future::BoxFuture;
3280 use futures::stream::BoxStream;
3281 use gpui::TestAppContext;
3282 use http_client;
3283 use indoc::indoc;
3284 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3285 use language_model::{
3286 LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3287 LanguageModelProviderName, LanguageModelToolChoice,
3288 };
3289 use parking_lot::Mutex;
3290 use project::{FakeFs, Project};
3291 use prompt_store::PromptBuilder;
3292 use serde_json::json;
3293 use settings::{Settings, SettingsStore};
3294 use std::sync::Arc;
3295 use std::time::Duration;
3296 use theme::ThemeSettings;
3297 use util::path;
3298 use workspace::Workspace;
3299
3300 #[gpui::test]
3301 async fn test_message_with_context(cx: &mut TestAppContext) {
3302 init_test_settings(cx);
3303
3304 let project = create_test_project(
3305 cx,
3306 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3307 )
3308 .await;
3309
3310 let (_workspace, _thread_store, thread, context_store, model) =
3311 setup_test_environment(cx, project.clone()).await;
3312
3313 add_file_to_context(&project, &context_store, "test/code.rs", cx)
3314 .await
3315 .unwrap();
3316
3317 let context =
3318 context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
3319 let loaded_context = cx
3320 .update(|cx| load_context(vec![context], &project, &None, cx))
3321 .await;
3322
3323 // Insert user message with context
3324 let message_id = thread.update(cx, |thread, cx| {
3325 thread.insert_user_message(
3326 "Please explain this code",
3327 loaded_context,
3328 None,
3329 Vec::new(),
3330 cx,
3331 )
3332 });
3333
3334 // Check content and context in message object
3335 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3336
3337 // Use different path format strings based on platform for the test
3338 #[cfg(windows)]
3339 let path_part = r"test\code.rs";
3340 #[cfg(not(windows))]
3341 let path_part = "test/code.rs";
3342
3343 let expected_context = format!(
3344 r#"
3345<context>
3346The following items were attached by the user. They are up-to-date and don't need to be re-read.
3347
3348<files>
3349```rs {path_part}
3350fn main() {{
3351 println!("Hello, world!");
3352}}
3353```
3354</files>
3355</context>
3356"#
3357 );
3358
3359 assert_eq!(message.role, Role::User);
3360 assert_eq!(message.segments.len(), 1);
3361 assert_eq!(
3362 message.segments[0],
3363 MessageSegment::Text("Please explain this code".to_string())
3364 );
3365 assert_eq!(message.loaded_context.text, expected_context);
3366
3367 // Check message in request
3368 let request = thread.update(cx, |thread, cx| {
3369 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3370 });
3371
3372 assert_eq!(request.messages.len(), 2);
3373 let expected_full_message = format!("{}Please explain this code", expected_context);
3374 assert_eq!(request.messages[1].string_contents(), expected_full_message);
3375 }
3376
3377 #[gpui::test]
3378 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
3379 init_test_settings(cx);
3380
3381 let project = create_test_project(
3382 cx,
3383 json!({
3384 "file1.rs": "fn function1() {}\n",
3385 "file2.rs": "fn function2() {}\n",
3386 "file3.rs": "fn function3() {}\n",
3387 "file4.rs": "fn function4() {}\n",
3388 }),
3389 )
3390 .await;
3391
3392 let (_, _thread_store, thread, context_store, model) =
3393 setup_test_environment(cx, project.clone()).await;
3394
3395 // First message with context 1
3396 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
3397 .await
3398 .unwrap();
3399 let new_contexts = context_store.update(cx, |store, cx| {
3400 store.new_context_for_thread(thread.read(cx), None)
3401 });
3402 assert_eq!(new_contexts.len(), 1);
3403 let loaded_context = cx
3404 .update(|cx| load_context(new_contexts, &project, &None, cx))
3405 .await;
3406 let message1_id = thread.update(cx, |thread, cx| {
3407 thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
3408 });
3409
3410 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
3411 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
3412 .await
3413 .unwrap();
3414 let new_contexts = context_store.update(cx, |store, cx| {
3415 store.new_context_for_thread(thread.read(cx), None)
3416 });
3417 assert_eq!(new_contexts.len(), 1);
3418 let loaded_context = cx
3419 .update(|cx| load_context(new_contexts, &project, &None, cx))
3420 .await;
3421 let message2_id = thread.update(cx, |thread, cx| {
3422 thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
3423 });
3424
3425 // Third message with all three contexts (contexts 1 and 2 should be skipped)
3426 //
3427 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
3428 .await
3429 .unwrap();
3430 let new_contexts = context_store.update(cx, |store, cx| {
3431 store.new_context_for_thread(thread.read(cx), None)
3432 });
3433 assert_eq!(new_contexts.len(), 1);
3434 let loaded_context = cx
3435 .update(|cx| load_context(new_contexts, &project, &None, cx))
3436 .await;
3437 let message3_id = thread.update(cx, |thread, cx| {
3438 thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
3439 });
3440
3441 // Check what contexts are included in each message
3442 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
3443 (
3444 thread.message(message1_id).unwrap().clone(),
3445 thread.message(message2_id).unwrap().clone(),
3446 thread.message(message3_id).unwrap().clone(),
3447 )
3448 });
3449
3450 // First message should include context 1
3451 assert!(message1.loaded_context.text.contains("file1.rs"));
3452
3453 // Second message should include only context 2 (not 1)
3454 assert!(!message2.loaded_context.text.contains("file1.rs"));
3455 assert!(message2.loaded_context.text.contains("file2.rs"));
3456
3457 // Third message should include only context 3 (not 1 or 2)
3458 assert!(!message3.loaded_context.text.contains("file1.rs"));
3459 assert!(!message3.loaded_context.text.contains("file2.rs"));
3460 assert!(message3.loaded_context.text.contains("file3.rs"));
3461
3462 // Check entire request to make sure all contexts are properly included
3463 let request = thread.update(cx, |thread, cx| {
3464 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3465 });
3466
3467 // The request should contain all 3 messages
3468 assert_eq!(request.messages.len(), 4);
3469
3470 // Check that the contexts are properly formatted in each message
3471 assert!(request.messages[1].string_contents().contains("file1.rs"));
3472 assert!(!request.messages[1].string_contents().contains("file2.rs"));
3473 assert!(!request.messages[1].string_contents().contains("file3.rs"));
3474
3475 assert!(!request.messages[2].string_contents().contains("file1.rs"));
3476 assert!(request.messages[2].string_contents().contains("file2.rs"));
3477 assert!(!request.messages[2].string_contents().contains("file3.rs"));
3478
3479 assert!(!request.messages[3].string_contents().contains("file1.rs"));
3480 assert!(!request.messages[3].string_contents().contains("file2.rs"));
3481 assert!(request.messages[3].string_contents().contains("file3.rs"));
3482
3483 add_file_to_context(&project, &context_store, "test/file4.rs", cx)
3484 .await
3485 .unwrap();
3486 let new_contexts = context_store.update(cx, |store, cx| {
3487 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3488 });
3489 assert_eq!(new_contexts.len(), 3);
3490 let loaded_context = cx
3491 .update(|cx| load_context(new_contexts, &project, &None, cx))
3492 .await
3493 .loaded_context;
3494
3495 assert!(!loaded_context.text.contains("file1.rs"));
3496 assert!(loaded_context.text.contains("file2.rs"));
3497 assert!(loaded_context.text.contains("file3.rs"));
3498 assert!(loaded_context.text.contains("file4.rs"));
3499
3500 let new_contexts = context_store.update(cx, |store, cx| {
3501 // Remove file4.rs
3502 store.remove_context(&loaded_context.contexts[2].handle(), cx);
3503 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3504 });
3505 assert_eq!(new_contexts.len(), 2);
3506 let loaded_context = cx
3507 .update(|cx| load_context(new_contexts, &project, &None, cx))
3508 .await
3509 .loaded_context;
3510
3511 assert!(!loaded_context.text.contains("file1.rs"));
3512 assert!(loaded_context.text.contains("file2.rs"));
3513 assert!(loaded_context.text.contains("file3.rs"));
3514 assert!(!loaded_context.text.contains("file4.rs"));
3515
3516 let new_contexts = context_store.update(cx, |store, cx| {
3517 // Remove file3.rs
3518 store.remove_context(&loaded_context.contexts[1].handle(), cx);
3519 store.new_context_for_thread(thread.read(cx), Some(message2_id))
3520 });
3521 assert_eq!(new_contexts.len(), 1);
3522 let loaded_context = cx
3523 .update(|cx| load_context(new_contexts, &project, &None, cx))
3524 .await
3525 .loaded_context;
3526
3527 assert!(!loaded_context.text.contains("file1.rs"));
3528 assert!(loaded_context.text.contains("file2.rs"));
3529 assert!(!loaded_context.text.contains("file3.rs"));
3530 assert!(!loaded_context.text.contains("file4.rs"));
3531 }
3532
3533 #[gpui::test]
3534 async fn test_message_without_files(cx: &mut TestAppContext) {
3535 init_test_settings(cx);
3536
3537 let project = create_test_project(
3538 cx,
3539 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3540 )
3541 .await;
3542
3543 let (_, _thread_store, thread, _context_store, model) =
3544 setup_test_environment(cx, project.clone()).await;
3545
3546 // Insert user message without any context (empty context vector)
3547 let message_id = thread.update(cx, |thread, cx| {
3548 thread.insert_user_message(
3549 "What is the best way to learn Rust?",
3550 ContextLoadResult::default(),
3551 None,
3552 Vec::new(),
3553 cx,
3554 )
3555 });
3556
3557 // Check content and context in message object
3558 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3559
3560 // Context should be empty when no files are included
3561 assert_eq!(message.role, Role::User);
3562 assert_eq!(message.segments.len(), 1);
3563 assert_eq!(
3564 message.segments[0],
3565 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3566 );
3567 assert_eq!(message.loaded_context.text, "");
3568
3569 // Check message in request
3570 let request = thread.update(cx, |thread, cx| {
3571 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3572 });
3573
3574 assert_eq!(request.messages.len(), 2);
3575 assert_eq!(
3576 request.messages[1].string_contents(),
3577 "What is the best way to learn Rust?"
3578 );
3579
3580 // Add second message, also without context
3581 let message2_id = thread.update(cx, |thread, cx| {
3582 thread.insert_user_message(
3583 "Are there any good books?",
3584 ContextLoadResult::default(),
3585 None,
3586 Vec::new(),
3587 cx,
3588 )
3589 });
3590
3591 let message2 =
3592 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3593 assert_eq!(message2.loaded_context.text, "");
3594
3595 // Check that both messages appear in the request
3596 let request = thread.update(cx, |thread, cx| {
3597 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3598 });
3599
3600 assert_eq!(request.messages.len(), 3);
3601 assert_eq!(
3602 request.messages[1].string_contents(),
3603 "What is the best way to learn Rust?"
3604 );
3605 assert_eq!(
3606 request.messages[2].string_contents(),
3607 "Are there any good books?"
3608 );
3609 }
3610
3611 #[gpui::test]
3612 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
3613 init_test_settings(cx);
3614
3615 let project = create_test_project(
3616 cx,
3617 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3618 )
3619 .await;
3620
3621 let (_workspace, _thread_store, thread, context_store, model) =
3622 setup_test_environment(cx, project.clone()).await;
3623
3624 // Add a buffer to the context. This will be a tracked buffer
3625 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
3626 .await
3627 .unwrap();
3628
3629 let context = context_store
3630 .read_with(cx, |store, _| store.context().next().cloned())
3631 .unwrap();
3632 let loaded_context = cx
3633 .update(|cx| load_context(vec![context], &project, &None, cx))
3634 .await;
3635
3636 // Insert user message and assistant response
3637 thread.update(cx, |thread, cx| {
3638 thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
3639 thread.insert_assistant_message(
3640 vec![MessageSegment::Text("This code prints 42.".into())],
3641 cx,
3642 );
3643 });
3644
3645 // We shouldn't have a stale buffer notification yet
3646 let notifications = thread.read_with(cx, |thread, _| {
3647 find_tool_uses(thread, "project_notifications")
3648 });
3649 assert!(
3650 notifications.is_empty(),
3651 "Should not have stale buffer notification before buffer is modified"
3652 );
3653
3654 // Modify the buffer
3655 buffer.update(cx, |buffer, cx| {
3656 buffer.edit(
3657 [(1..1, "\n println!(\"Added a new line\");\n")],
3658 None,
3659 cx,
3660 );
3661 });
3662
3663 // Insert another user message
3664 thread.update(cx, |thread, cx| {
3665 thread.insert_user_message(
3666 "What does the code do now?",
3667 ContextLoadResult::default(),
3668 None,
3669 Vec::new(),
3670 cx,
3671 )
3672 });
3673
3674 // Check for the stale buffer warning
3675 thread.update(cx, |thread, cx| {
3676 thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
3677 });
3678
3679 let notifications = thread.read_with(cx, |thread, _cx| {
3680 find_tool_uses(thread, "project_notifications")
3681 });
3682
3683 let [notification] = notifications.as_slice() else {
3684 panic!("Should have a `project_notifications` tool use");
3685 };
3686
3687 let Some(notification_content) = notification.content.to_str() else {
3688 panic!("`project_notifications` should return text");
3689 };
3690
3691 let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]
3692
3693 These files have changed since the last read:
3694 - code.rs
3695 "};
3696 assert_eq!(notification_content, expected_content);
3697
3698 // Insert another user message and flush notifications again
3699 thread.update(cx, |thread, cx| {
3700 thread.insert_user_message(
3701 "Can you tell me more?",
3702 ContextLoadResult::default(),
3703 None,
3704 Vec::new(),
3705 cx,
3706 )
3707 });
3708
3709 thread.update(cx, |thread, cx| {
3710 thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
3711 });
3712
3713 // There should be no new notifications (we already flushed one)
3714 let notifications = thread.read_with(cx, |thread, _cx| {
3715 find_tool_uses(thread, "project_notifications")
3716 });
3717
3718 assert_eq!(
3719 notifications.len(),
3720 1,
3721 "Should still have only one notification after second flush - no duplicates"
3722 );
3723 }
3724
3725 fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3726 thread
3727 .messages()
3728 .flat_map(|message| {
3729 thread
3730 .tool_results_for_message(message.id)
3731 .into_iter()
3732 .filter(|result| result.tool_name == tool_name.into())
3733 .cloned()
3734 .collect::<Vec<_>>()
3735 })
3736 .collect()
3737 }
3738
3739 #[gpui::test]
3740 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3741 init_test_settings(cx);
3742
3743 let project = create_test_project(
3744 cx,
3745 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3746 )
3747 .await;
3748
3749 let (_workspace, thread_store, thread, _context_store, _model) =
3750 setup_test_environment(cx, project.clone()).await;
3751
3752 // Check that we are starting with the default profile
3753 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3754 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3755 assert_eq!(
3756 profile,
3757 AgentProfile::new(AgentProfileId::default(), tool_set)
3758 );
3759 }
3760
3761 #[gpui::test]
3762 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3763 init_test_settings(cx);
3764
3765 let project = create_test_project(
3766 cx,
3767 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3768 )
3769 .await;
3770
3771 let (_workspace, thread_store, thread, _context_store, _model) =
3772 setup_test_environment(cx, project.clone()).await;
3773
3774 // Profile gets serialized with default values
3775 let serialized = thread
3776 .update(cx, |thread, cx| thread.serialize(cx))
3777 .await
3778 .unwrap();
3779
3780 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3781
3782 let deserialized = cx.update(|cx| {
3783 thread.update(cx, |thread, cx| {
3784 Thread::deserialize(
3785 thread.id.clone(),
3786 serialized,
3787 thread.project.clone(),
3788 thread.tools.clone(),
3789 thread.prompt_builder.clone(),
3790 thread.project_context.clone(),
3791 None,
3792 cx,
3793 )
3794 })
3795 });
3796 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3797
3798 assert_eq!(
3799 deserialized.profile,
3800 AgentProfile::new(AgentProfileId::default(), tool_set)
3801 );
3802 }
3803
3804 #[gpui::test]
3805 async fn test_temperature_setting(cx: &mut TestAppContext) {
3806 init_test_settings(cx);
3807
3808 let project = create_test_project(
3809 cx,
3810 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3811 )
3812 .await;
3813
3814 let (_workspace, _thread_store, thread, _context_store, model) =
3815 setup_test_environment(cx, project.clone()).await;
3816
3817 // Both model and provider
3818 cx.update(|cx| {
3819 AgentSettings::override_global(
3820 AgentSettings {
3821 model_parameters: vec![LanguageModelParameters {
3822 provider: Some(model.provider_id().0.to_string().into()),
3823 model: Some(model.id().0.clone()),
3824 temperature: Some(0.66),
3825 }],
3826 ..AgentSettings::get_global(cx).clone()
3827 },
3828 cx,
3829 );
3830 });
3831
3832 let request = thread.update(cx, |thread, cx| {
3833 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3834 });
3835 assert_eq!(request.temperature, Some(0.66));
3836
3837 // Only model
3838 cx.update(|cx| {
3839 AgentSettings::override_global(
3840 AgentSettings {
3841 model_parameters: vec![LanguageModelParameters {
3842 provider: None,
3843 model: Some(model.id().0.clone()),
3844 temperature: Some(0.66),
3845 }],
3846 ..AgentSettings::get_global(cx).clone()
3847 },
3848 cx,
3849 );
3850 });
3851
3852 let request = thread.update(cx, |thread, cx| {
3853 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3854 });
3855 assert_eq!(request.temperature, Some(0.66));
3856
3857 // Only provider
3858 cx.update(|cx| {
3859 AgentSettings::override_global(
3860 AgentSettings {
3861 model_parameters: vec![LanguageModelParameters {
3862 provider: Some(model.provider_id().0.to_string().into()),
3863 model: None,
3864 temperature: Some(0.66),
3865 }],
3866 ..AgentSettings::get_global(cx).clone()
3867 },
3868 cx,
3869 );
3870 });
3871
3872 let request = thread.update(cx, |thread, cx| {
3873 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3874 });
3875 assert_eq!(request.temperature, Some(0.66));
3876
3877 // Same model name, different provider
3878 cx.update(|cx| {
3879 AgentSettings::override_global(
3880 AgentSettings {
3881 model_parameters: vec![LanguageModelParameters {
3882 provider: Some("anthropic".into()),
3883 model: Some(model.id().0.clone()),
3884 temperature: Some(0.66),
3885 }],
3886 ..AgentSettings::get_global(cx).clone()
3887 },
3888 cx,
3889 );
3890 });
3891
3892 let request = thread.update(cx, |thread, cx| {
3893 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3894 });
3895 assert_eq!(request.temperature, None);
3896 }
3897
3898 #[gpui::test]
3899 async fn test_thread_summary(cx: &mut TestAppContext) {
3900 init_test_settings(cx);
3901
3902 let project = create_test_project(cx, json!({})).await;
3903
3904 let (_, _thread_store, thread, _context_store, model) =
3905 setup_test_environment(cx, project.clone()).await;
3906
3907 // Initial state should be pending
3908 thread.read_with(cx, |thread, _| {
3909 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3910 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3911 });
3912
3913 // Manually setting the summary should not be allowed in this state
3914 thread.update(cx, |thread, cx| {
3915 thread.set_summary("This should not work", cx);
3916 });
3917
3918 thread.read_with(cx, |thread, _| {
3919 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3920 });
3921
3922 // Send a message
3923 thread.update(cx, |thread, cx| {
3924 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3925 thread.send_to_model(
3926 model.clone(),
3927 CompletionIntent::ThreadSummarization,
3928 None,
3929 cx,
3930 );
3931 });
3932
3933 let fake_model = model.as_fake();
3934 simulate_successful_response(&fake_model, cx);
3935
3936 // Should start generating summary when there are >= 2 messages
3937 thread.read_with(cx, |thread, _| {
3938 assert_eq!(*thread.summary(), ThreadSummary::Generating);
3939 });
3940
3941 // Should not be able to set the summary while generating
3942 thread.update(cx, |thread, cx| {
3943 thread.set_summary("This should not work either", cx);
3944 });
3945
3946 thread.read_with(cx, |thread, _| {
3947 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3948 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3949 });
3950
3951 cx.run_until_parked();
3952 fake_model.stream_last_completion_response("Brief");
3953 fake_model.stream_last_completion_response(" Introduction");
3954 fake_model.end_last_completion_stream();
3955 cx.run_until_parked();
3956
3957 // Summary should be set
3958 thread.read_with(cx, |thread, _| {
3959 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3960 assert_eq!(thread.summary().or_default(), "Brief Introduction");
3961 });
3962
3963 // Now we should be able to set a summary
3964 thread.update(cx, |thread, cx| {
3965 thread.set_summary("Brief Intro", cx);
3966 });
3967
3968 thread.read_with(cx, |thread, _| {
3969 assert_eq!(thread.summary().or_default(), "Brief Intro");
3970 });
3971
3972 // Test setting an empty summary (should default to DEFAULT)
3973 thread.update(cx, |thread, cx| {
3974 thread.set_summary("", cx);
3975 });
3976
3977 thread.read_with(cx, |thread, _| {
3978 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3979 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3980 });
3981 }
3982
3983 #[gpui::test]
3984 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3985 init_test_settings(cx);
3986
3987 let project = create_test_project(cx, json!({})).await;
3988
3989 let (_, _thread_store, thread, _context_store, model) =
3990 setup_test_environment(cx, project.clone()).await;
3991
3992 test_summarize_error(&model, &thread, cx);
3993
3994 // Now we should be able to set a summary
3995 thread.update(cx, |thread, cx| {
3996 thread.set_summary("Brief Intro", cx);
3997 });
3998
3999 thread.read_with(cx, |thread, _| {
4000 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4001 assert_eq!(thread.summary().or_default(), "Brief Intro");
4002 });
4003 }
4004
4005 #[gpui::test]
4006 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
4007 init_test_settings(cx);
4008
4009 let project = create_test_project(cx, json!({})).await;
4010
4011 let (_, _thread_store, thread, _context_store, model) =
4012 setup_test_environment(cx, project.clone()).await;
4013
4014 test_summarize_error(&model, &thread, cx);
4015
4016 // Sending another message should not trigger another summarize request
4017 thread.update(cx, |thread, cx| {
4018 thread.insert_user_message(
4019 "How are you?",
4020 ContextLoadResult::default(),
4021 None,
4022 vec![],
4023 cx,
4024 );
4025 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4026 });
4027
4028 let fake_model = model.as_fake();
4029 simulate_successful_response(&fake_model, cx);
4030
4031 thread.read_with(cx, |thread, _| {
4032 // State is still Error, not Generating
4033 assert!(matches!(thread.summary(), ThreadSummary::Error));
4034 });
4035
4036 // But the summarize request can be invoked manually
4037 thread.update(cx, |thread, cx| {
4038 thread.summarize(cx);
4039 });
4040
4041 thread.read_with(cx, |thread, _| {
4042 assert!(matches!(thread.summary(), ThreadSummary::Generating));
4043 });
4044
4045 cx.run_until_parked();
4046 fake_model.stream_last_completion_response("A successful summary");
4047 fake_model.end_last_completion_stream();
4048 cx.run_until_parked();
4049
4050 thread.read_with(cx, |thread, _| {
4051 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4052 assert_eq!(thread.summary().or_default(), "A successful summary");
4053 });
4054 }
4055
4056 // Helper to create a model that returns errors
4057 enum TestError {
4058 Overloaded,
4059 InternalServerError,
4060 }
4061
4062 struct ErrorInjector {
4063 inner: Arc<FakeLanguageModel>,
4064 error_type: TestError,
4065 }
4066
4067 impl ErrorInjector {
4068 fn new(error_type: TestError) -> Self {
4069 Self {
4070 inner: Arc::new(FakeLanguageModel::default()),
4071 error_type,
4072 }
4073 }
4074 }
4075
4076 impl LanguageModel for ErrorInjector {
4077 fn id(&self) -> LanguageModelId {
4078 self.inner.id()
4079 }
4080
4081 fn name(&self) -> LanguageModelName {
4082 self.inner.name()
4083 }
4084
4085 fn provider_id(&self) -> LanguageModelProviderId {
4086 self.inner.provider_id()
4087 }
4088
4089 fn provider_name(&self) -> LanguageModelProviderName {
4090 self.inner.provider_name()
4091 }
4092
4093 fn supports_tools(&self) -> bool {
4094 self.inner.supports_tools()
4095 }
4096
4097 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4098 self.inner.supports_tool_choice(choice)
4099 }
4100
4101 fn supports_images(&self) -> bool {
4102 self.inner.supports_images()
4103 }
4104
4105 fn telemetry_id(&self) -> String {
4106 self.inner.telemetry_id()
4107 }
4108
4109 fn max_token_count(&self) -> u64 {
4110 self.inner.max_token_count()
4111 }
4112
4113 fn count_tokens(
4114 &self,
4115 request: LanguageModelRequest,
4116 cx: &App,
4117 ) -> BoxFuture<'static, Result<u64>> {
4118 self.inner.count_tokens(request, cx)
4119 }
4120
4121 fn stream_completion(
4122 &self,
4123 _request: LanguageModelRequest,
4124 _cx: &AsyncApp,
4125 ) -> BoxFuture<
4126 'static,
4127 Result<
4128 BoxStream<
4129 'static,
4130 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4131 >,
4132 LanguageModelCompletionError,
4133 >,
4134 > {
4135 let error = match self.error_type {
4136 TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
4137 provider: self.provider_name(),
4138 retry_after: None,
4139 },
4140 TestError::InternalServerError => {
4141 LanguageModelCompletionError::ApiInternalServerError {
4142 provider: self.provider_name(),
4143 message: "I'm a teapot orbiting the sun".to_string(),
4144 }
4145 }
4146 };
4147 async move {
4148 let stream = futures::stream::once(async move { Err(error) });
4149 Ok(stream.boxed())
4150 }
4151 .boxed()
4152 }
4153
4154 fn as_fake(&self) -> &FakeLanguageModel {
4155 &self.inner
4156 }
4157 }
4158
4159 #[gpui::test]
4160 async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4161 init_test_settings(cx);
4162
4163 let project = create_test_project(cx, json!({})).await;
4164 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4165
4166 // Create model that returns overloaded error
4167 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4168
4169 // Insert a user message
4170 thread.update(cx, |thread, cx| {
4171 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4172 });
4173
4174 // Start completion
4175 thread.update(cx, |thread, cx| {
4176 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4177 });
4178
4179 cx.run_until_parked();
4180
4181 thread.read_with(cx, |thread, _| {
4182 assert!(thread.retry_state.is_some(), "Should have retry state");
4183 let retry_state = thread.retry_state.as_ref().unwrap();
4184 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4185 assert_eq!(
4186 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4187 "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
4188 );
4189 });
4190
4191 // Check that a retry message was added
4192 thread.read_with(cx, |thread, _| {
4193 let mut messages = thread.messages();
4194 assert!(
4195 messages.any(|msg| {
4196 msg.role == Role::System
4197 && msg.ui_only
4198 && msg.segments.iter().any(|seg| {
4199 if let MessageSegment::Text(text) = seg {
4200 text.contains("overloaded")
4201 && text
4202 .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4203 } else {
4204 false
4205 }
4206 })
4207 }),
4208 "Should have added a system retry message"
4209 );
4210 });
4211
4212 let retry_count = thread.update(cx, |thread, _| {
4213 thread
4214 .messages
4215 .iter()
4216 .filter(|m| {
4217 m.ui_only
4218 && m.segments.iter().any(|s| {
4219 if let MessageSegment::Text(text) = s {
4220 text.contains("Retrying") && text.contains("seconds")
4221 } else {
4222 false
4223 }
4224 })
4225 })
4226 .count()
4227 });
4228
4229 assert_eq!(retry_count, 1, "Should have one retry message");
4230 }
4231
4232 #[gpui::test]
4233 async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4234 init_test_settings(cx);
4235
4236 let project = create_test_project(cx, json!({})).await;
4237 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4238
4239 // Create model that returns internal server error
4240 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4241
4242 // Insert a user message
4243 thread.update(cx, |thread, cx| {
4244 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4245 });
4246
4247 // Start completion
4248 thread.update(cx, |thread, cx| {
4249 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4250 });
4251
4252 cx.run_until_parked();
4253
4254 // Check retry state on thread
4255 thread.read_with(cx, |thread, _| {
4256 assert!(thread.retry_state.is_some(), "Should have retry state");
4257 let retry_state = thread.retry_state.as_ref().unwrap();
4258 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4259 assert_eq!(
4260 retry_state.max_attempts, 1,
4261 "Should have correct max attempts"
4262 );
4263 });
4264
4265 // Check that a retry message was added with provider name
4266 thread.read_with(cx, |thread, _| {
4267 let mut messages = thread.messages();
4268 assert!(
4269 messages.any(|msg| {
4270 msg.role == Role::System
4271 && msg.ui_only
4272 && msg.segments.iter().any(|seg| {
4273 if let MessageSegment::Text(text) = seg {
4274 text.contains("internal")
4275 && text.contains("Fake")
4276 && text.contains("Retrying in")
4277 && !text.contains("attempt")
4278 } else {
4279 false
4280 }
4281 })
4282 }),
4283 "Should have added a system retry message with provider name"
4284 );
4285 });
4286
4287 // Count retry messages
4288 let retry_count = thread.update(cx, |thread, _| {
4289 thread
4290 .messages
4291 .iter()
4292 .filter(|m| {
4293 m.ui_only
4294 && m.segments.iter().any(|s| {
4295 if let MessageSegment::Text(text) = s {
4296 text.contains("Retrying") && text.contains("seconds")
4297 } else {
4298 false
4299 }
4300 })
4301 })
4302 .count()
4303 });
4304
4305 assert_eq!(retry_count, 1, "Should have one retry message");
4306 }
4307
4308 #[gpui::test]
4309 async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4310 init_test_settings(cx);
4311
4312 let project = create_test_project(cx, json!({})).await;
4313 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4314
4315 // Create model that returns internal server error
4316 let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4317
4318 // Insert a user message
4319 thread.update(cx, |thread, cx| {
4320 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4321 });
4322
4323 // Track retry events and completion count
4324 // Track completion events
4325 let completion_count = Arc::new(Mutex::new(0));
4326 let completion_count_clone = completion_count.clone();
4327
4328 let _subscription = thread.update(cx, |_, cx| {
4329 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4330 if let ThreadEvent::NewRequest = event {
4331 *completion_count_clone.lock() += 1;
4332 }
4333 })
4334 });
4335
4336 // First attempt
4337 thread.update(cx, |thread, cx| {
4338 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4339 });
4340 cx.run_until_parked();
4341
4342 // Should have scheduled first retry - count retry messages
4343 let retry_count = thread.update(cx, |thread, _| {
4344 thread
4345 .messages
4346 .iter()
4347 .filter(|m| {
4348 m.ui_only
4349 && m.segments.iter().any(|s| {
4350 if let MessageSegment::Text(text) = s {
4351 text.contains("Retrying") && text.contains("seconds")
4352 } else {
4353 false
4354 }
4355 })
4356 })
4357 .count()
4358 });
4359 assert_eq!(retry_count, 1, "Should have scheduled first retry");
4360
4361 // Check retry state
4362 thread.read_with(cx, |thread, _| {
4363 assert!(thread.retry_state.is_some(), "Should have retry state");
4364 let retry_state = thread.retry_state.as_ref().unwrap();
4365 assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4366 assert_eq!(
4367 retry_state.max_attempts, 1,
4368 "Internal server errors should only retry once"
4369 );
4370 });
4371
4372 // Advance clock for first retry
4373 cx.executor().advance_clock(BASE_RETRY_DELAY);
4374 cx.run_until_parked();
4375
4376 // Should have scheduled second retry - count retry messages
4377 let retry_count = thread.update(cx, |thread, _| {
4378 thread
4379 .messages
4380 .iter()
4381 .filter(|m| {
4382 m.ui_only
4383 && m.segments.iter().any(|s| {
4384 if let MessageSegment::Text(text) = s {
4385 text.contains("Retrying") && text.contains("seconds")
4386 } else {
4387 false
4388 }
4389 })
4390 })
4391 .count()
4392 });
4393 assert_eq!(
4394 retry_count, 1,
4395 "Should have only one retry for internal server errors"
4396 );
4397
4398 // For internal server errors, we only retry once and then give up
4399 // Check that retry_state is cleared after the single retry
4400 thread.read_with(cx, |thread, _| {
4401 assert!(
4402 thread.retry_state.is_none(),
4403 "Retry state should be cleared after single retry"
4404 );
4405 });
4406
4407 // Verify total attempts (1 initial + 1 retry)
4408 assert_eq!(
4409 *completion_count.lock(),
4410 2,
4411 "Should have attempted once plus 1 retry"
4412 );
4413 }
4414
4415 #[gpui::test]
4416 async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4417 init_test_settings(cx);
4418
4419 let project = create_test_project(cx, json!({})).await;
4420 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4421
4422 // Create model that returns overloaded error
4423 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4424
4425 // Insert a user message
4426 thread.update(cx, |thread, cx| {
4427 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4428 });
4429
4430 // Track events
4431 let stopped_with_error = Arc::new(Mutex::new(false));
4432 let stopped_with_error_clone = stopped_with_error.clone();
4433
4434 let _subscription = thread.update(cx, |_, cx| {
4435 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4436 if let ThreadEvent::Stopped(Err(_)) = event {
4437 *stopped_with_error_clone.lock() = true;
4438 }
4439 })
4440 });
4441
4442 // Start initial completion
4443 thread.update(cx, |thread, cx| {
4444 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4445 });
4446 cx.run_until_parked();
4447
4448 // Advance through all retries
4449 for _ in 0..MAX_RETRY_ATTEMPTS {
4450 cx.executor().advance_clock(BASE_RETRY_DELAY);
4451 cx.run_until_parked();
4452 }
4453
4454 let retry_count = thread.update(cx, |thread, _| {
4455 thread
4456 .messages
4457 .iter()
4458 .filter(|m| {
4459 m.ui_only
4460 && m.segments.iter().any(|s| {
4461 if let MessageSegment::Text(text) = s {
4462 text.contains("Retrying") && text.contains("seconds")
4463 } else {
4464 false
4465 }
4466 })
4467 })
4468 .count()
4469 });
4470
4471 // After max retries, should emit Stopped(Err(...)) event
4472 assert_eq!(
4473 retry_count, MAX_RETRY_ATTEMPTS as usize,
4474 "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
4475 );
4476 assert!(
4477 *stopped_with_error.lock(),
4478 "Should emit Stopped(Err(...)) event after max retries exceeded"
4479 );
4480
4481 // Retry state should be cleared
4482 thread.read_with(cx, |thread, _| {
4483 assert!(
4484 thread.retry_state.is_none(),
4485 "Retry state should be cleared after max retries"
4486 );
4487
4488 // Verify we have the expected number of retry messages
4489 let retry_messages = thread
4490 .messages
4491 .iter()
4492 .filter(|msg| msg.ui_only && msg.role == Role::System)
4493 .count();
4494 assert_eq!(
4495 retry_messages, MAX_RETRY_ATTEMPTS as usize,
4496 "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
4497 );
4498 });
4499 }
4500
4501 #[gpui::test]
4502 async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
4503 init_test_settings(cx);
4504
4505 let project = create_test_project(cx, json!({})).await;
4506 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4507
4508 // We'll use a wrapper to switch behavior after first failure
4509 struct RetryTestModel {
4510 inner: Arc<FakeLanguageModel>,
4511 failed_once: Arc<Mutex<bool>>,
4512 }
4513
4514 impl LanguageModel for RetryTestModel {
4515 fn id(&self) -> LanguageModelId {
4516 self.inner.id()
4517 }
4518
4519 fn name(&self) -> LanguageModelName {
4520 self.inner.name()
4521 }
4522
4523 fn provider_id(&self) -> LanguageModelProviderId {
4524 self.inner.provider_id()
4525 }
4526
4527 fn provider_name(&self) -> LanguageModelProviderName {
4528 self.inner.provider_name()
4529 }
4530
4531 fn supports_tools(&self) -> bool {
4532 self.inner.supports_tools()
4533 }
4534
4535 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4536 self.inner.supports_tool_choice(choice)
4537 }
4538
4539 fn supports_images(&self) -> bool {
4540 self.inner.supports_images()
4541 }
4542
4543 fn telemetry_id(&self) -> String {
4544 self.inner.telemetry_id()
4545 }
4546
4547 fn max_token_count(&self) -> u64 {
4548 self.inner.max_token_count()
4549 }
4550
4551 fn count_tokens(
4552 &self,
4553 request: LanguageModelRequest,
4554 cx: &App,
4555 ) -> BoxFuture<'static, Result<u64>> {
4556 self.inner.count_tokens(request, cx)
4557 }
4558
4559 fn stream_completion(
4560 &self,
4561 request: LanguageModelRequest,
4562 cx: &AsyncApp,
4563 ) -> BoxFuture<
4564 'static,
4565 Result<
4566 BoxStream<
4567 'static,
4568 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4569 >,
4570 LanguageModelCompletionError,
4571 >,
4572 > {
4573 if !*self.failed_once.lock() {
4574 *self.failed_once.lock() = true;
4575 let provider = self.provider_name();
4576 // Return error on first attempt
4577 let stream = futures::stream::once(async move {
4578 Err(LanguageModelCompletionError::ServerOverloaded {
4579 provider,
4580 retry_after: None,
4581 })
4582 });
4583 async move { Ok(stream.boxed()) }.boxed()
4584 } else {
4585 // Succeed on retry
4586 self.inner.stream_completion(request, cx)
4587 }
4588 }
4589
4590 fn as_fake(&self) -> &FakeLanguageModel {
4591 &self.inner
4592 }
4593 }
4594
4595 let model = Arc::new(RetryTestModel {
4596 inner: Arc::new(FakeLanguageModel::default()),
4597 failed_once: Arc::new(Mutex::new(false)),
4598 });
4599
4600 // Insert a user message
4601 thread.update(cx, |thread, cx| {
4602 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4603 });
4604
4605 // Track message deletions
4606 // Track when retry completes successfully
4607 let retry_completed = Arc::new(Mutex::new(false));
4608 let retry_completed_clone = retry_completed.clone();
4609
4610 let _subscription = thread.update(cx, |_, cx| {
4611 cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4612 if let ThreadEvent::StreamedCompletion = event {
4613 *retry_completed_clone.lock() = true;
4614 }
4615 })
4616 });
4617
4618 // Start completion
4619 thread.update(cx, |thread, cx| {
4620 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4621 });
4622 cx.run_until_parked();
4623
4624 // Get the retry message ID
4625 let retry_message_id = thread.read_with(cx, |thread, _| {
4626 thread
4627 .messages()
4628 .find(|msg| msg.role == Role::System && msg.ui_only)
4629 .map(|msg| msg.id)
4630 .expect("Should have a retry message")
4631 });
4632
4633 // Wait for retry
4634 cx.executor().advance_clock(BASE_RETRY_DELAY);
4635 cx.run_until_parked();
4636
4637 // Stream some successful content
4638 let fake_model = model.as_fake();
4639 // After the retry, there should be a new pending completion
4640 let pending = fake_model.pending_completions();
4641 assert!(
4642 !pending.is_empty(),
4643 "Should have a pending completion after retry"
4644 );
4645 fake_model.stream_completion_response(&pending[0], "Success!");
4646 fake_model.end_completion_stream(&pending[0]);
4647 cx.run_until_parked();
4648
4649 // Check that the retry completed successfully
4650 assert!(
4651 *retry_completed.lock(),
4652 "Retry should have completed successfully"
4653 );
4654
4655 // Retry message should still exist but be marked as ui_only
4656 thread.read_with(cx, |thread, _| {
4657 let retry_msg = thread
4658 .message(retry_message_id)
4659 .expect("Retry message should still exist");
4660 assert!(retry_msg.ui_only, "Retry message should be ui_only");
4661 assert_eq!(
4662 retry_msg.role,
4663 Role::System,
4664 "Retry message should have System role"
4665 );
4666 });
4667 }
4668
4669 #[gpui::test]
4670 async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4671 init_test_settings(cx);
4672
4673 let project = create_test_project(cx, json!({})).await;
4674 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4675
4676 // Create a model that fails once then succeeds
4677 struct FailOnceModel {
4678 inner: Arc<FakeLanguageModel>,
4679 failed_once: Arc<Mutex<bool>>,
4680 }
4681
4682 impl LanguageModel for FailOnceModel {
4683 fn id(&self) -> LanguageModelId {
4684 self.inner.id()
4685 }
4686
4687 fn name(&self) -> LanguageModelName {
4688 self.inner.name()
4689 }
4690
4691 fn provider_id(&self) -> LanguageModelProviderId {
4692 self.inner.provider_id()
4693 }
4694
4695 fn provider_name(&self) -> LanguageModelProviderName {
4696 self.inner.provider_name()
4697 }
4698
4699 fn supports_tools(&self) -> bool {
4700 self.inner.supports_tools()
4701 }
4702
4703 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4704 self.inner.supports_tool_choice(choice)
4705 }
4706
4707 fn supports_images(&self) -> bool {
4708 self.inner.supports_images()
4709 }
4710
4711 fn telemetry_id(&self) -> String {
4712 self.inner.telemetry_id()
4713 }
4714
4715 fn max_token_count(&self) -> u64 {
4716 self.inner.max_token_count()
4717 }
4718
4719 fn count_tokens(
4720 &self,
4721 request: LanguageModelRequest,
4722 cx: &App,
4723 ) -> BoxFuture<'static, Result<u64>> {
4724 self.inner.count_tokens(request, cx)
4725 }
4726
4727 fn stream_completion(
4728 &self,
4729 request: LanguageModelRequest,
4730 cx: &AsyncApp,
4731 ) -> BoxFuture<
4732 'static,
4733 Result<
4734 BoxStream<
4735 'static,
4736 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4737 >,
4738 LanguageModelCompletionError,
4739 >,
4740 > {
4741 if !*self.failed_once.lock() {
4742 *self.failed_once.lock() = true;
4743 let provider = self.provider_name();
4744 // Return error on first attempt
4745 let stream = futures::stream::once(async move {
4746 Err(LanguageModelCompletionError::ServerOverloaded {
4747 provider,
4748 retry_after: None,
4749 })
4750 });
4751 async move { Ok(stream.boxed()) }.boxed()
4752 } else {
4753 // Succeed on retry
4754 self.inner.stream_completion(request, cx)
4755 }
4756 }
4757 }
4758
4759 let fail_once_model = Arc::new(FailOnceModel {
4760 inner: Arc::new(FakeLanguageModel::default()),
4761 failed_once: Arc::new(Mutex::new(false)),
4762 });
4763
4764 // Insert a user message
4765 thread.update(cx, |thread, cx| {
4766 thread.insert_user_message(
4767 "Test message",
4768 ContextLoadResult::default(),
4769 None,
4770 vec![],
4771 cx,
4772 );
4773 });
4774
4775 // Start completion with fail-once model
4776 thread.update(cx, |thread, cx| {
4777 thread.send_to_model(
4778 fail_once_model.clone(),
4779 CompletionIntent::UserPrompt,
4780 None,
4781 cx,
4782 );
4783 });
4784
4785 cx.run_until_parked();
4786
4787 // Verify retry state exists after first failure
4788 thread.read_with(cx, |thread, _| {
4789 assert!(
4790 thread.retry_state.is_some(),
4791 "Should have retry state after failure"
4792 );
4793 });
4794
4795 // Wait for retry delay
4796 cx.executor().advance_clock(BASE_RETRY_DELAY);
4797 cx.run_until_parked();
4798
4799 // The retry should now use our FailOnceModel which should succeed
4800 // We need to help the FakeLanguageModel complete the stream
4801 let inner_fake = fail_once_model.inner.clone();
4802
4803 // Wait a bit for the retry to start
4804 cx.run_until_parked();
4805
4806 // Check for pending completions and complete them
4807 if let Some(pending) = inner_fake.pending_completions().first() {
4808 inner_fake.stream_completion_response(pending, "Success!");
4809 inner_fake.end_completion_stream(pending);
4810 }
4811 cx.run_until_parked();
4812
4813 thread.read_with(cx, |thread, _| {
4814 assert!(
4815 thread.retry_state.is_none(),
4816 "Retry state should be cleared after successful completion"
4817 );
4818
4819 let has_assistant_message = thread
4820 .messages
4821 .iter()
4822 .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4823 assert!(
4824 has_assistant_message,
4825 "Should have an assistant message after successful retry"
4826 );
4827 });
4828 }
4829
4830 #[gpui::test]
4831 async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4832 init_test_settings(cx);
4833
4834 let project = create_test_project(cx, json!({})).await;
4835 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4836
4837 // Create a model that returns rate limit error with retry_after
4838 struct RateLimitModel {
4839 inner: Arc<FakeLanguageModel>,
4840 }
4841
4842 impl LanguageModel for RateLimitModel {
4843 fn id(&self) -> LanguageModelId {
4844 self.inner.id()
4845 }
4846
4847 fn name(&self) -> LanguageModelName {
4848 self.inner.name()
4849 }
4850
4851 fn provider_id(&self) -> LanguageModelProviderId {
4852 self.inner.provider_id()
4853 }
4854
4855 fn provider_name(&self) -> LanguageModelProviderName {
4856 self.inner.provider_name()
4857 }
4858
4859 fn supports_tools(&self) -> bool {
4860 self.inner.supports_tools()
4861 }
4862
4863 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4864 self.inner.supports_tool_choice(choice)
4865 }
4866
4867 fn supports_images(&self) -> bool {
4868 self.inner.supports_images()
4869 }
4870
4871 fn telemetry_id(&self) -> String {
4872 self.inner.telemetry_id()
4873 }
4874
4875 fn max_token_count(&self) -> u64 {
4876 self.inner.max_token_count()
4877 }
4878
4879 fn count_tokens(
4880 &self,
4881 request: LanguageModelRequest,
4882 cx: &App,
4883 ) -> BoxFuture<'static, Result<u64>> {
4884 self.inner.count_tokens(request, cx)
4885 }
4886
4887 fn stream_completion(
4888 &self,
4889 _request: LanguageModelRequest,
4890 _cx: &AsyncApp,
4891 ) -> BoxFuture<
4892 'static,
4893 Result<
4894 BoxStream<
4895 'static,
4896 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4897 >,
4898 LanguageModelCompletionError,
4899 >,
4900 > {
4901 let provider = self.provider_name();
4902 async move {
4903 let stream = futures::stream::once(async move {
4904 Err(LanguageModelCompletionError::RateLimitExceeded {
4905 provider,
4906 retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4907 })
4908 });
4909 Ok(stream.boxed())
4910 }
4911 .boxed()
4912 }
4913
4914 fn as_fake(&self) -> &FakeLanguageModel {
4915 &self.inner
4916 }
4917 }
4918
4919 let model = Arc::new(RateLimitModel {
4920 inner: Arc::new(FakeLanguageModel::default()),
4921 });
4922
4923 // Insert a user message
4924 thread.update(cx, |thread, cx| {
4925 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4926 });
4927
4928 // Start completion
4929 thread.update(cx, |thread, cx| {
4930 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4931 });
4932
4933 cx.run_until_parked();
4934
4935 let retry_count = thread.update(cx, |thread, _| {
4936 thread
4937 .messages
4938 .iter()
4939 .filter(|m| {
4940 m.ui_only
4941 && m.segments.iter().any(|s| {
4942 if let MessageSegment::Text(text) = s {
4943 text.contains("rate limit exceeded")
4944 } else {
4945 false
4946 }
4947 })
4948 })
4949 .count()
4950 });
4951 assert_eq!(retry_count, 1, "Should have scheduled one retry");
4952
4953 thread.read_with(cx, |thread, _| {
4954 assert!(
4955 thread.retry_state.is_some(),
4956 "Rate limit errors should set retry_state"
4957 );
4958 if let Some(retry_state) = &thread.retry_state {
4959 assert_eq!(
4960 retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4961 "Rate limit errors should use MAX_RETRY_ATTEMPTS"
4962 );
4963 }
4964 });
4965
4966 // Verify we have one retry message
4967 thread.read_with(cx, |thread, _| {
4968 let retry_messages = thread
4969 .messages
4970 .iter()
4971 .filter(|msg| {
4972 msg.ui_only
4973 && msg.segments.iter().any(|seg| {
4974 if let MessageSegment::Text(text) = seg {
4975 text.contains("rate limit exceeded")
4976 } else {
4977 false
4978 }
4979 })
4980 })
4981 .count();
4982 assert_eq!(
4983 retry_messages, 1,
4984 "Should have one rate limit retry message"
4985 );
4986 });
4987
4988 // Check that retry message doesn't include attempt count
4989 thread.read_with(cx, |thread, _| {
4990 let retry_message = thread
4991 .messages
4992 .iter()
4993 .find(|msg| msg.role == Role::System && msg.ui_only)
4994 .expect("Should have a retry message");
4995
4996 // Check that the message contains attempt count since we use retry_state
4997 if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
4998 assert!(
4999 text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)),
5000 "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
5001 );
5002 assert!(
5003 text.contains("Retrying"),
5004 "Rate limit retry message should contain retry text"
5005 );
5006 }
5007 });
5008 }
5009
5010 #[gpui::test]
5011 async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5012 init_test_settings(cx);
5013
5014 let project = create_test_project(cx, json!({})).await;
5015 let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5016
5017 // Insert a regular user message
5018 thread.update(cx, |thread, cx| {
5019 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5020 });
5021
5022 // Insert a UI-only message (like our retry notifications)
5023 thread.update(cx, |thread, cx| {
5024 let id = thread.next_message_id.post_inc();
5025 thread.messages.push(Message {
5026 id,
5027 role: Role::System,
5028 segments: vec![MessageSegment::Text(
5029 "This is a UI-only message that should not be sent to the model".to_string(),
5030 )],
5031 loaded_context: LoadedContext::default(),
5032 creases: Vec::new(),
5033 is_hidden: true,
5034 ui_only: true,
5035 });
5036 cx.emit(ThreadEvent::MessageAdded(id));
5037 });
5038
5039 // Insert another regular message
5040 thread.update(cx, |thread, cx| {
5041 thread.insert_user_message(
5042 "How are you?",
5043 ContextLoadResult::default(),
5044 None,
5045 vec![],
5046 cx,
5047 );
5048 });
5049
5050 // Generate the completion request
5051 let request = thread.update(cx, |thread, cx| {
5052 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5053 });
5054
5055 // Verify that the request only contains non-UI-only messages
5056 // Should have system prompt + 2 user messages, but not the UI-only message
5057 let user_messages: Vec<_> = request
5058 .messages
5059 .iter()
5060 .filter(|msg| msg.role == Role::User)
5061 .collect();
5062 assert_eq!(
5063 user_messages.len(),
5064 2,
5065 "Should have exactly 2 user messages"
5066 );
5067
5068 // Verify the UI-only content is not present anywhere in the request
5069 let request_text = request
5070 .messages
5071 .iter()
5072 .flat_map(|msg| &msg.content)
5073 .filter_map(|content| match content {
5074 MessageContent::Text(text) => Some(text.as_str()),
5075 _ => None,
5076 })
5077 .collect::<String>();
5078
5079 assert!(
5080 !request_text.contains("UI-only message"),
5081 "UI-only message content should not be in the request"
5082 );
5083
5084 // Verify the thread still has all 3 messages (including UI-only)
5085 thread.read_with(cx, |thread, _| {
5086 assert_eq!(
5087 thread.messages().count(),
5088 3,
5089 "Thread should have 3 messages"
5090 );
5091 assert_eq!(
5092 thread.messages().filter(|m| m.ui_only).count(),
5093 1,
5094 "Thread should have 1 UI-only message"
5095 );
5096 });
5097
5098 // Verify that UI-only messages are not serialized
5099 let serialized = thread
5100 .update(cx, |thread, cx| thread.serialize(cx))
5101 .await
5102 .unwrap();
5103 assert_eq!(
5104 serialized.messages.len(),
5105 2,
5106 "Serialized thread should only have 2 messages (no UI-only)"
5107 );
5108 }
5109
5110 #[gpui::test]
5111 async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5112 init_test_settings(cx);
5113
5114 let project = create_test_project(cx, json!({})).await;
5115 let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5116
5117 // Create model that returns overloaded error
5118 let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5119
5120 // Insert a user message
5121 thread.update(cx, |thread, cx| {
5122 thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5123 });
5124
5125 // Start completion
5126 thread.update(cx, |thread, cx| {
5127 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5128 });
5129
5130 cx.run_until_parked();
5131
5132 // Verify retry was scheduled by checking for retry message
5133 let has_retry_message = thread.read_with(cx, |thread, _| {
5134 thread.messages.iter().any(|m| {
5135 m.ui_only
5136 && m.segments.iter().any(|s| {
5137 if let MessageSegment::Text(text) = s {
5138 text.contains("Retrying") && text.contains("seconds")
5139 } else {
5140 false
5141 }
5142 })
5143 })
5144 });
5145 assert!(has_retry_message, "Should have scheduled a retry");
5146
5147 // Cancel the completion before the retry happens
5148 thread.update(cx, |thread, cx| {
5149 thread.cancel_last_completion(None, cx);
5150 });
5151
5152 cx.run_until_parked();
5153
5154 // The retry should not have happened - no pending completions
5155 let fake_model = model.as_fake();
5156 assert_eq!(
5157 fake_model.pending_completions().len(),
5158 0,
5159 "Should have no pending completions after cancellation"
5160 );
5161
5162 // Verify the retry was cancelled by checking retry state
5163 thread.read_with(cx, |thread, _| {
5164 if let Some(retry_state) = &thread.retry_state {
5165 panic!(
5166 "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5167 retry_state.attempt, retry_state.max_attempts, retry_state.intent
5168 );
5169 }
5170 });
5171 }
5172
5173 fn test_summarize_error(
5174 model: &Arc<dyn LanguageModel>,
5175 thread: &Entity<Thread>,
5176 cx: &mut TestAppContext,
5177 ) {
5178 thread.update(cx, |thread, cx| {
5179 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5180 thread.send_to_model(
5181 model.clone(),
5182 CompletionIntent::ThreadSummarization,
5183 None,
5184 cx,
5185 );
5186 });
5187
5188 let fake_model = model.as_fake();
5189 simulate_successful_response(&fake_model, cx);
5190
5191 thread.read_with(cx, |thread, _| {
5192 assert!(matches!(thread.summary(), ThreadSummary::Generating));
5193 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5194 });
5195
5196 // Simulate summary request ending
5197 cx.run_until_parked();
5198 fake_model.end_last_completion_stream();
5199 cx.run_until_parked();
5200
5201 // State is set to Error and default message
5202 thread.read_with(cx, |thread, _| {
5203 assert!(matches!(thread.summary(), ThreadSummary::Error));
5204 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5205 });
5206 }
5207
5208 fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
5209 cx.run_until_parked();
5210 fake_model.stream_last_completion_response("Assistant response");
5211 fake_model.end_last_completion_stream();
5212 cx.run_until_parked();
5213 }
5214
5215 fn init_test_settings(cx: &mut TestAppContext) {
5216 cx.update(|cx| {
5217 let settings_store = SettingsStore::test(cx);
5218 cx.set_global(settings_store);
5219 language::init(cx);
5220 Project::init_settings(cx);
5221 AgentSettings::register(cx);
5222 prompt_store::init(cx);
5223 thread_store::init(cx);
5224 workspace::init_settings(cx);
5225 language_model::init_settings(cx);
5226 ThemeSettings::register(cx);
5227 ToolRegistry::default_global(cx);
5228 assistant_tool::init(cx);
5229
5230 let http_client = Arc::new(http_client::HttpClientWithUrl::new(
5231 http_client::FakeHttpClient::with_200_response(),
5232 "http://localhost".to_string(),
5233 None,
5234 ));
5235 assistant_tools::init(http_client, cx);
5236 });
5237 }
5238
5239 // Helper to create a test project with test files
5240 async fn create_test_project(
5241 cx: &mut TestAppContext,
5242 files: serde_json::Value,
5243 ) -> Entity<Project> {
5244 let fs = FakeFs::new(cx.executor());
5245 fs.insert_tree(path!("/test"), files).await;
5246 Project::test(fs, [path!("/test").as_ref()], cx).await
5247 }
5248
5249 async fn setup_test_environment(
5250 cx: &mut TestAppContext,
5251 project: Entity<Project>,
5252 ) -> (
5253 Entity<Workspace>,
5254 Entity<ThreadStore>,
5255 Entity<Thread>,
5256 Entity<ContextStore>,
5257 Arc<dyn LanguageModel>,
5258 ) {
5259 let (workspace, cx) =
5260 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5261
5262 let thread_store = cx
5263 .update(|_, cx| {
5264 ThreadStore::load(
5265 project.clone(),
5266 cx.new(|_| ToolWorkingSet::default()),
5267 None,
5268 Arc::new(PromptBuilder::new(None).unwrap()),
5269 cx,
5270 )
5271 })
5272 .await
5273 .unwrap();
5274
5275 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5276 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5277
5278 let provider = Arc::new(FakeLanguageModelProvider);
5279 let model = provider.test_model();
5280 let model: Arc<dyn LanguageModel> = Arc::new(model);
5281
5282 cx.update(|_, cx| {
5283 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5284 registry.set_default_model(
5285 Some(ConfiguredModel {
5286 provider: provider.clone(),
5287 model: model.clone(),
5288 }),
5289 cx,
5290 );
5291 registry.set_thread_summary_model(
5292 Some(ConfiguredModel {
5293 provider,
5294 model: model.clone(),
5295 }),
5296 cx,
5297 );
5298 })
5299 });
5300
5301 (workspace, thread_store, thread, context_store, model)
5302 }
5303
5304 async fn add_file_to_context(
5305 project: &Entity<Project>,
5306 context_store: &Entity<ContextStore>,
5307 path: &str,
5308 cx: &mut TestAppContext,
5309 ) -> Result<Entity<language::Buffer>> {
5310 let buffer_path = project
5311 .read_with(cx, |project, cx| project.find_project_path(path, cx))
5312 .unwrap();
5313
5314 let buffer = project
5315 .update(cx, |project, cx| {
5316 project.open_buffer(buffer_path.clone(), cx)
5317 })
5318 .await
5319 .unwrap();
5320
5321 context_store.update(cx, |context_store, cx| {
5322 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5323 });
5324
5325 Ok(buffer)
5326 }
5327}