1use crate::{
2 AgentThread, AgentThreadUserMessageChunk, MessageId, ThreadId,
3 agent_profile::AgentProfile,
4 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
5 thread_store::{SharedProjectContext, ThreadStore},
6};
7use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use client::{ModelRequestUsage, RequestUsage};
12use collections::{HashMap, HashSet};
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
18 WeakEntity, Window,
19};
20use language_model::{
21 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
22 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
23 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
24 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
25 ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
26};
27use postage::stream::Stream as _;
28use project::{
29 Project,
30 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
31};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use serde::{Deserialize, Serialize};
35use settings::Settings;
36use std::{
37 io::Write,
38 ops::Range,
39 sync::Arc,
40 time::{Duration, Instant},
41};
42use thiserror::Error;
43use util::{ResultExt as _, post_inc};
44use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
45
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range the crease occupies in the message text (assumed byte offsets — confirm with usage).
    pub range: Range<usize>,
    /// Icon displayed for the crease placeholder.
    pub icon_path: SharedString,
    /// Label displayed for the crease placeholder.
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
55
/// Lifecycle state of a tool invocation attached to a [`Message`].
pub enum MessageTool {
    /// The tool call has been received but has not run or been confirmed yet.
    Pending {
        tool: Arc<dyn Tool>,
        input: serde_json::Value,
    },
    /// The tool is awaiting user confirmation, delivered through `confirm_tx`
    /// (presumably `true` = approve, `false` = decline — confirm with callers).
    NeedsConfirmation {
        tool: Arc<dyn Tool>,
        input_json: serde_json::Value,
        confirm_tx: oneshot::Sender<bool>,
    },
    /// The tool was confirmed; `card` is its UI representation.
    Confirmed {
        card: AnyToolCard,
    },
    /// The user declined to run the tool.
    Declined {
        tool: Arc<dyn Tool>,
        input_json: serde_json::Value,
    },
}
74
/// A message in a [`Thread`].
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Reasoning ("thinking") text produced by the model; may be empty.
    pub thinking: String,
    /// The main body text of the message.
    pub text: String,
    /// Tool invocations associated with this message.
    pub tools: Vec<MessageTool>,
    /// Context that was loaded and attached to this message.
    pub loaded_context: LoadedContext,
    /// Creases used to rebuild context placeholders when editing this message.
    pub creases: Vec<MessageCrease>,
    /// Whether the message is hidden (name suggests: not shown in the UI — confirm).
    pub is_hidden: bool,
    /// Whether the message exists only for UI purposes (name suggests: never sent to the model — confirm).
    pub ui_only: bool,
}
87
/// Point-in-time capture of the project state, including per-worktree git
/// information and which buffers had unsaved changes.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that were dirty when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree: its absolute path and git state, if any.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Git repository details captured for a worktree snapshot.
/// All fields are best-effort and may be None if unavailable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff of HEAD against the worktree, when it could be computed.
    pub diff: Option<String>,
}
108
/// Associates a message with a git checkpoint of the project, so the project
/// can later be restored to the state it had around that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up / thumbs down) on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
120
/// The most recent checkpoint-restore operation: either still in flight
/// or failed with an error (cleared entirely on success).
pub enum LastRestoreCheckpoint {
    /// The restore is still running.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed; `error` is the user-displayable reason.
    Error {
        message_id: MessageId,
        error: String,
    },
}
130
131impl LastRestoreCheckpoint {
132 pub fn message_id(&self) -> MessageId {
133 match self {
134 LastRestoreCheckpoint::Pending { message_id } => *message_id,
135 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
136 }
137 }
138}
139
/// State machine for the on-demand detailed summary of a thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// A summary covering messages up to and including `message_id` is being generated.
    Generating {
        message_id: MessageId,
    },
    /// A summary covering messages up to and including `message_id` is available.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
152
153impl DetailedSummaryState {
154 fn text(&self) -> Option<SharedString> {
155 if let Self::Generated { text, .. } = self {
156 Some(text.clone())
157 } else {
158 None
159 }
160 }
161}
162
/// Running token usage for a thread against a model's context-window maximum.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Context-window maximum; 0 means unknown (no model selected).
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or exceeded.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via an env
        // var for testing. An unset OR malformed value falls back to 0.8
        // instead of panicking (the previous `.parse().unwrap()` would crash
        // on any garbage in ZED_THREAD_WARNING_THRESHOLD).
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy with `tokens` added to the total; `max` is unchanged.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Severity classification produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
207
/// Progress of a completion request through the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is waiting behind `position` other requests.
    Queued { position: usize },
    /// The request has started being processed.
    Started,
}
214
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Backing agent-side thread that owns the conversation.
    agent_thread: Arc<dyn AgentThread>,
    /// Title shown in the UI; may still be pending or generating.
    title: ThreadTitle,
    /// In-flight send, if a generation is currently running.
    pending_send: Option<Task<Result<()>>>,
    pending_summary: Task<Option<()>>,
    /// Task driving detailed-summary generation; replaced (cancelling the
    /// previous run) whenever a new summary is requested.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// UI-side message list; assumed sorted by ascending id
    /// (`message()` binary-searches it).
    messages: Vec<Message>,
    /// Git checkpoints recorded per message.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    project: Entity<Project>,
    /// Log of agent-made edits, used for keep/reject operations.
    action_log: Entity<ActionLog>,
    /// Most recent checkpoint restore, while pending or after failure.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured for the current turn; finalized once generation stops.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured when the thread was created; shared so
    /// multiple consumers can await the same result.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Total token usage accumulated over the thread's lifetime.
    cumulative_token_usage: TokenUsage,
    /// Set when a request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    // todo!(keep track of retries from the underlying agent)
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// Rate-limits automatic telemetry capture.
    last_auto_capture_at: Option<Instant>,
    /// When the most recent streaming chunk arrived; drives staleness detection.
    last_received_chunk_at: Option<Instant>,
}
242
/// Lifecycle of a thread's auto-generated title.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadTitle {
    /// Title generation has not started.
    Pending,
    /// A title is currently being generated.
    Generating,
    /// Title generation completed successfully.
    Ready(SharedString),
    /// Title generation failed.
    Error,
}
250
251impl ThreadTitle {
252 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
253
254 pub fn or_default(&self) -> SharedString {
255 self.unwrap_or(Self::DEFAULT)
256 }
257
258 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
259 self.ready().unwrap_or_else(|| message.into())
260 }
261
262 pub fn ready(&self) -> Option<SharedString> {
263 match self {
264 ThreadTitle::Ready(summary) => Some(summary.clone()),
265 ThreadTitle::Pending | ThreadTitle::Generating | ThreadTitle::Error => None,
266 }
267 }
268}
269
/// Records that a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
277
278impl Thread {
    /// Constructs a `Thread` wrapper around an existing [`AgentThread`],
    /// initializing UI-side state (title, summaries, checkpoints, token
    /// accounting) and eagerly capturing an initial project snapshot.
    ///
    /// NOTE(review): `messages` is still a `todo!` — loading will panic until
    /// messages are read from the agent.
    pub fn load(
        agent_thread: Arc<dyn AgentThread>,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        Self {
            agent_thread,
            title: ThreadTitle::Pending,
            pending_send: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: todo!("read from agent"),
            checkpoints_by_message: HashMap::default(),
            project: project.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Start the snapshot immediately; `shared()` lets multiple
                // consumers await the same result.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
        }
    }
316
    /// The id of the underlying agent thread.
    pub fn id(&self) -> ThreadId {
        self.agent_thread.id()
    }

    /// The agent profile currently in use (stub).
    pub fn profile(&self) -> &AgentProfile {
        todo!()
    }

    /// Switches to the profile with the given id (stub; see commented logic
    /// for the intended behavior).
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        todo!()
        // if &id != self.profile.id() {
        //     self.profile = AgentProfile::new(id, self.tools.clone());
        //     cx.emit(ThreadEvent::ProfileChanged);
        // }
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Shared project context used when building prompts (stub).
    pub fn project_context(&self) -> SharedProjectContext {
        todo!()
        // self.project_context.clone()
    }

    /// The thread's title; may still be pending or generating.
    pub fn title(&self) -> &ThreadTitle {
        &self.title
    }
345
    /// Sets the thread title (stub; see commented logic for the intended
    /// behavior, including the fallback to the default title).
    pub fn set_title(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        todo!()
        // let current_summary = match &self.summary {
        //     ThreadSummary::Pending | ThreadSummary::Generating => return,
        //     ThreadSummary::Ready(summary) => summary,
        //     ThreadSummary::Error => &ThreadSummary::DEFAULT,
        // };

        // let mut new_summary = new_summary.into();

        // if new_summary.is_empty() {
        //     new_summary = ThreadSummary::DEFAULT;
        // }

        // if current_summary != &new_summary {
        //     self.summary = ThreadSummary::Ready(new_summary);
        //     cx.emit(ThreadEvent::SummaryChanged);
        // }
    }

    /// Re-runs summary generation for this thread (stub).
    pub fn regenerate_summary(&self, cx: &mut Context<Self>) {
        todo!()
        // self.summarize(cx);
    }

    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
378
379 pub fn message(&self, id: MessageId) -> Option<&Message> {
380 let index = self
381 .messages
382 .binary_search_by(|message| message.id.cmp(&id))
383 .ok()?;
384
385 self.messages.get(index)
386 }
387
    /// All messages in the thread, in order.
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }

    /// Whether a send is currently in flight.
    pub fn is_generating(&self) -> bool {
        self.pending_send.is_some()
    }
395
396 /// Indicates whether streaming of language model events is stale.
397 /// When `is_generating()` is false, this method returns `None`.
398 pub fn is_generation_stale(&self) -> Option<bool> {
399 const STALE_THRESHOLD: u128 = 250;
400
401 self.last_received_chunk_at
402 .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
403 }
404
    /// Records that a streaming chunk just arrived; consulted by
    /// `is_generation_stale` to detect a stalled stream.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// The git checkpoint associated with the given message, if one was recorded.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
412
    /// Restores the project to the git state recorded in `checkpoint`.
    ///
    /// The restore is marked pending (so the UI can show progress), then the
    /// git store performs it asynchronously. On success the thread is
    /// truncated back to the checkpoint's message; on failure the error is
    /// kept in `last_restore_checkpoint` for display.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so the UI can surface it.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way, the turn's pending checkpoint is now moot.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
448
449 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
450 let pending_checkpoint = if self.is_generating() {
451 return;
452 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
453 checkpoint
454 } else {
455 return;
456 };
457
458 self.finalize_checkpoint(pending_checkpoint, cx);
459 }
460
    /// Compares the checkpoint taken at the start of the turn with the
    /// repository's current state; if they differ, the checkpoint is recorded
    /// so the user can restore it. If the comparison cannot be performed, the
    /// checkpoint is kept conservatively.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison counts as "not equal" so the checkpoint
                // is retained.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep checkpoints that captured actual changes.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint to compare against: keep the
            // pending one rather than silently dropping it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
495
    /// Records a checkpoint under its message id and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The most recent checkpoint restore, while pending or after failure.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
506
    /// Removes the message with `message_id` and everything after it, dropping
    /// the associated checkpoints.
    ///
    /// NOTE(review): the body below the `todo!` is unreachable until the
    /// agent-side truncate call is wired up.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        todo!("call truncate on the agent");
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
521
    /// Whether the message at index `ix` ends a turn (stub; see commented
    /// logic for the intended behavior).
    pub fn is_turn_end(&self, ix: usize) -> bool {
        todo!()
        // if self.messages.is_empty() {
        //     return false;
        // }

        // if !self.is_generating() && ix == self.messages.len() - 1 {
        //     return true;
        // }

        // let Some(message) = self.messages.get(ix) else {
        //     return false;
        // };

        // if message.role != Role::Assistant {
        //     return false;
        // }

        // self.messages
        //     .get(ix + 1)
        //     .and_then(|message| {
        //         self.message(message.id)
        //             .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
        //     })
        //     .unwrap_or(false)
    }

    /// Whether the tool-use limit has been reached for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        todo!()
    }
557
558 // pub fn insert_user_message(
559 // &mut self,
560 // text: impl Into<String>,
561 // loaded_context: ContextLoadResult,
562 // git_checkpoint: Option<GitStoreCheckpoint>,
563 // creases: Vec<MessageCrease>,
564 // cx: &mut Context<Self>,
565 // ) -> AgentThreadMessageId {
566 // todo!("move this logic into send")
567 // if !loaded_context.referenced_buffers.is_empty() {
568 // self.action_log.update(cx, |log, cx| {
569 // for buffer in loaded_context.referenced_buffers {
570 // log.buffer_read(buffer, cx);
571 // }
572 // });
573 // }
574
575 // let message_id = self.insert_message(
576 // Role::User,
577 // vec![MessageSegment::Text(text.into())],
578 // loaded_context.loaded_context,
579 // creases,
580 // false,
581 // cx,
582 // );
583
584 // if let Some(git_checkpoint) = git_checkpoint {
585 // self.pending_checkpoint = Some(ThreadCheckpoint {
586 // message_id,
587 // git_checkpoint,
588 // });
589 // }
590
591 // self.auto_capture_telemetry(cx);
592
593 // message_id
594 // }
595
    /// Sets the language model used by this thread (stub).
    pub fn set_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        todo!()
    }

    /// The currently configured language model, if any (stub).
    pub fn model(&self) -> Option<ConfiguredModel> {
        todo!()
    }

    /// Sends a new user message to the agent (stub).
    pub fn send(
        &mut self,
        message: Vec<AgentThreadUserMessageChunk>,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        todo!()
    }

    /// Resumes generation on this thread (stub).
    pub fn resume(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        todo!()
    }

    /// Edits the user message with `message_id`, replacing its content (stub).
    pub fn edit(
        &mut self,
        message_id: MessageId,
        message: Vec<AgentThreadUserMessageChunk>,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        todo!()
    }

    /// Cancels the in-flight generation; the `bool` presumably reports whether
    /// anything was cancelled (stub — confirm once implemented).
    pub fn cancel(&mut self, window: &mut Window, cx: &mut Context<Self>) -> bool {
        todo!()
    }
630
631 // pub fn insert_invisible_continue_message(
632 // &mut self,
633 // cx: &mut Context<Self>,
634 // ) -> AgentThreadMessageId {
635 // let id = self.insert_message(
636 // Role::User,
637 // vec![MessageSegment::Text("Continue where you left off".into())],
638 // LoadedContext::default(),
639 // vec![],
640 // true,
641 // cx,
642 // );
643 // self.pending_checkpoint = None;
644
645 // id
646 // }
647
648 // pub fn insert_assistant_message(
649 // &mut self,
650 // segments: Vec<MessageSegment>,
651 // cx: &mut Context<Self>,
652 // ) -> AgentThreadMessageId {
653 // self.insert_message(
654 // Role::Assistant,
655 // segments,
656 // LoadedContext::default(),
657 // Vec::new(),
658 // false,
659 // cx,
660 // )
661 // }
662
663 // pub fn insert_message(
664 // &mut self,
665 // role: Role,
666 // segments: Vec<MessageSegment>,
667 // loaded_context: LoadedContext,
668 // creases: Vec<MessageCrease>,
669 // is_hidden: bool,
670 // cx: &mut Context<Self>,
671 // ) -> AgentThreadMessageId {
672 // let id = self.next_message_id.post_inc();
673 // self.messages.push(Message {
674 // id,
675 // role,
676 // segments,
677 // loaded_context,
678 // creases,
679 // is_hidden,
680 // ui_only: false,
681 // });
682 // self.touch_updated_at();
683 // cx.emit(ThreadEvent::MessageAdded(id));
684 // id
685 // }
686
687 // pub fn edit_message(
688 // &mut self,
689 // id: AgentThreadMessageId,
690 // new_role: Role,
691 // new_segments: Vec<MessageSegment>,
692 // creases: Vec<MessageCrease>,
693 // loaded_context: Option<LoadedContext>,
694 // checkpoint: Option<GitStoreCheckpoint>,
695 // cx: &mut Context<Self>,
696 // ) -> bool {
697 // let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
698 // return false;
699 // };
700 // message.role = new_role;
701 // message.segments = new_segments;
702 // message.creases = creases;
703 // if let Some(context) = loaded_context {
704 // message.loaded_context = context;
705 // }
706 // if let Some(git_checkpoint) = checkpoint {
707 // self.checkpoints_by_message.insert(
708 // id,
709 // ThreadCheckpoint {
710 // message_id: id,
711 // git_checkpoint,
712 // },
713 // );
714 // }
715 // self.touch_updated_at();
716 // cx.emit(ThreadEvent::MessageEdited(id));
717 // true
718 // }
719
720 /// Returns the representation of this [`Thread`] in a textual form.
721 ///
722 /// This is the representation we use when attaching a thread as context to another thread.
723 pub fn text(&self) -> String {
724 let mut text = String::new();
725
726 for message in &self.messages {
727 text.push_str(match message.role {
728 language_model::Role::User => "User:",
729 language_model::Role::Assistant => "Agent:",
730 language_model::Role::System => "System:",
731 });
732 text.push('\n');
733
734 text.push_str("<think>");
735 text.push_str(&message.thinking);
736 text.push_str("</think>");
737 text.push_str(&message.text);
738
739 // todo!('what about tools?');
740
741 text.push('\n');
742 }
743
744 text
745 }
746
    /// Whether any tools were used after the most recent user message (stub;
    /// see commented logic for the intended behavior).
    pub fn used_tools_since_last_user_message(&self) -> bool {
        todo!()
        // for message in self.messages.iter().rev() {
        //     if self.tool_use.message_has_tool_results(message.id) {
        //         return true;
        //     } else if message.role == Role::User {
        //         return false;
        //     }
        // }

        // false
    }
759
    /// Kicks off detailed-summary generation unless a summary covering the
    /// current last message is already generating or generated. Progress is
    /// published through the `detailed_summary_tx` watch channel.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let summary = self.agent_thread.summary();

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            // On failure, roll the watch state back so a retry is possible.
            let Some(summary) = summary.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            Some(())
        });
    }
813
    /// Waits until a detailed summary becomes available, returning it; if
    /// generation is not in progress (state is `NotGenerated`), falls back to
    /// the plain-text representation of the thread. Returns `None` only if the
    /// thread entity has been dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Keep waiting while generation is in flight.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
831
    /// The most recently generated detailed summary, or the plain-text
    /// representation if none has been generated.
    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.detailed_summary_rx
            .borrow()
            .text()
            .unwrap_or_else(|| self.text().into())
    }

    /// Whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
845
    /// Thread-level feedback, if the user provided any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }

    /// Feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
853
    /// Records feedback for a message and reports it via telemetry (stub; see
    /// commented logic for the intended behavior).
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        todo!()
        // if self.message_feedback.get(&message_id) == Some(&feedback) {
        //     return Task::ready(Ok(()));
        // }

        // let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        // let serialized_thread = self.serialize(cx);
        // let thread_id = self.id().clone();
        // let client = self.project.read(cx).client();

        // let enabled_tool_names: Vec<String> = self
        //     .profile
        //     .enabled_tools(cx)
        //     .iter()
        //     .map(|tool| tool.name())
        //     .collect();

        // self.message_feedback.insert(message_id, feedback);

        // cx.notify();

        // let message_content = self
        //     .message(message_id)
        //     .map(|msg| msg.to_string())
        //     .unwrap_or_default();

        // cx.background_spawn(async move {
        //     let final_project_snapshot = final_project_snapshot.await;
        //     let serialized_thread = serialized_thread.await?;
        //     let thread_data =
        //         serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

        //     let rating = match feedback {
        //         ThreadFeedback::Positive => "positive",
        //         ThreadFeedback::Negative => "negative",
        //     };
        //     telemetry::event!(
        //         "Assistant Thread Rated",
        //         rating,
        //         thread_id,
        //         enabled_tool_names,
        //         message_id = message_id,
        //         message_content,
        //         thread_data,
        //         final_project_snapshot
        //     );
        //     client.telemetry().flush_events().await;

        //     Ok(())
        // })
    }

    /// Records feedback for the whole thread, attributing it to the last
    /// assistant message when one exists (stub; see commented logic).
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        todo!()
        // let last_assistant_message_id = self
        //     .messages
        //     .iter()
        //     .rev()
        //     .find(|msg| msg.role == Role::Assistant)
        //     .map(|msg| msg.id);

        // if let Some(message_id) = last_assistant_message_id {
        //     self.report_message_feedback(message_id, feedback, cx)
        // } else {
        //     let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        //     let serialized_thread = self.serialize(cx);
        //     let thread_id = self.id().clone();
        //     let client = self.project.read(cx).client();
        //     self.feedback = Some(feedback);
        //     cx.notify();

        //     cx.background_spawn(async move {
        //         let final_project_snapshot = final_project_snapshot.await;
        //         let serialized_thread = serialized_thread.await?;
        //         let thread_data = serde_json::to_value(serialized_thread)
        //             .unwrap_or_else(|_| serde_json::Value::Null);

        //         let rating = match feedback {
        //             ThreadFeedback::Positive => "positive",
        //             ThreadFeedback::Negative => "negative",
        //         };
        //         telemetry::event!(
        //             "Assistant Thread Rated",
        //             rating,
        //             thread_id,
        //             thread_data,
        //             final_project_snapshot
        //         );
        //         client.telemetry().flush_events().await;

        //         Ok(())
        //     })
        // }
    }
958
959 /// Create a snapshot of the current project state including git information and unsaved buffers.
960 fn project_snapshot(
961 project: Entity<Project>,
962 cx: &mut Context<Self>,
963 ) -> Task<Arc<ProjectSnapshot>> {
964 let git_store = project.read(cx).git_store().clone();
965 let worktree_snapshots: Vec<_> = project
966 .read(cx)
967 .visible_worktrees(cx)
968 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
969 .collect();
970
971 cx.spawn(async move |_, cx| {
972 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
973
974 let mut unsaved_buffers = Vec::new();
975 cx.update(|app_cx| {
976 let buffer_store = project.read(app_cx).buffer_store();
977 for buffer_handle in buffer_store.read(app_cx).buffers() {
978 let buffer = buffer_handle.read(app_cx);
979 if buffer.is_dirty() {
980 if let Some(file) = buffer.file() {
981 let path = file.path().to_string_lossy().to_string();
982 unsaved_buffers.push(path);
983 }
984 }
985 }
986 })
987 .ok();
988
989 Arc::new(ProjectSnapshot {
990 worktree_snapshots,
991 unsaved_buffer_paths: unsaved_buffers,
992 timestamp: Utc::now(),
993 })
994 })
995 }
996
    /// Captures a [`WorktreeSnapshot`] (absolute path plus git state) for one
    /// worktree. All failures degrade gracefully: a dropped worktree yields an
    /// empty-path snapshot, and any git failure yields `git_state: None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository containing this worktree, then gather its
            // state on the repository's background job queue.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Non-local repositories only report the branch.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task>> — any failure along the
            // way results in no git state.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
1074
    /// Renders the thread as Markdown (stub; see the commented logic for the
    /// intended format: title heading, per-message role headings, context,
    /// segments, tool uses, and tool results).
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        todo!()
        // let mut markdown = Vec::new();

        // let summary = self.summary().or_default();
        // writeln!(markdown, "# {summary}\n")?;

        // for message in self.messages() {
        //     writeln!(
        //         markdown,
        //         "## {role}\n",
        //         role = match message.role {
        //             Role::User => "User",
        //             Role::Assistant => "Agent",
        //             Role::System => "System",
        //         }
        //     )?;

        //     if !message.loaded_context.text.is_empty() {
        //         writeln!(markdown, "{}", message.loaded_context.text)?;
        //     }

        //     if !message.loaded_context.images.is_empty() {
        //         writeln!(
        //             markdown,
        //             "\n{} images attached as context.\n",
        //             message.loaded_context.images.len()
        //         )?;
        //     }

        //     for segment in &message.segments {
        //         match segment {
        //             MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
        //             MessageSegment::Thinking { text, .. } => {
        //                 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
        //             }
        //             MessageSegment::RedactedThinking(_) => {}
        //         }
        //     }

        //     for tool_use in self.tool_uses_for_message(message.id, cx) {
        //         writeln!(
        //             markdown,
        //             "**Use Tool: {} ({})**",
        //             tool_use.name, tool_use.id
        //         )?;
        //         writeln!(markdown, "```json")?;
        //         writeln!(
        //             markdown,
        //             "{}",
        //             serde_json::to_string_pretty(&tool_use.input)?
        //         )?;
        //         writeln!(markdown, "```")?;
        //     }

        //     for tool_result in self.tool_results_for_message(message.id) {
        //         write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
        //         if tool_result.is_error {
        //             write!(markdown, " (Error)")?;
        //         }

        //         writeln!(markdown, "**\n")?;
        //         match &tool_result.content {
        //             LanguageModelToolResultContent::Text(text) => {
        //                 writeln!(markdown, "{text}")?;
        //             }
        //             LanguageModelToolResultContent::Image(image) => {
        //                 writeln!(markdown, "", image.source)?;
        //             }
        //         }

        //         if let Some(output) = tool_result.output.as_ref() {
        //             writeln!(
        //                 markdown,
        //                 "\n\nDebug Output:\n\n```json\n{}\n```\n",
        //                 serde_json::to_string_pretty(output)?
        //             )?;
        //         }
        //     }
        // }

        // Ok(String::from_utf8_lossy(&markdown).to_string())
    }
1158
1159 pub fn keep_edits_in_range(
1160 &mut self,
1161 buffer: Entity<language::Buffer>,
1162 buffer_range: Range<language::Anchor>,
1163 cx: &mut Context<Self>,
1164 ) {
1165 self.action_log.update(cx, |action_log, cx| {
1166 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1167 });
1168 }
1169
1170 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1171 self.action_log
1172 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1173 }
1174
1175 pub fn reject_edits_in_ranges(
1176 &mut self,
1177 buffer: Entity<language::Buffer>,
1178 buffer_ranges: Vec<Range<language::Anchor>>,
1179 cx: &mut Context<Self>,
1180 ) -> Task<Result<()>> {
1181 self.action_log.update(cx, |action_log, cx| {
1182 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1183 })
1184 }
1185
    /// Returns a handle to this thread's [`ActionLog`].
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1189
    /// Returns a handle to the [`Project`] this thread belongs to.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1193
    /// Automatically captures the serialized thread as a telemetry event for
    /// feature-flagged users.
    ///
    /// NOTE(review): currently unimplemented — calling this panics via
    /// `todo!()`. The commented-out body below is the pre-refactor
    /// implementation (feature-flag gate, 10-second debounce via
    /// `last_auto_capture_at`, then serialize + emit a telemetry event and
    /// flush); restore or remove it when the refactor lands.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        todo!()
        // if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
        //     return;
        // }

        // let now = Instant::now();
        // if let Some(last) = self.last_auto_capture_at {
        //     if now.duration_since(last).as_secs() < 10 {
        //         return;
        //     }
        // }

        // self.last_auto_capture_at = Some(now);

        // let thread_id = self.id().clone();
        // let github_login = self
        //     .project
        //     .read(cx)
        //     .user_store()
        //     .read(cx)
        //     .current_user()
        //     .map(|user| user.github_login.clone());
        // let client = self.project.read(cx).client();
        // let serialize_task = self.serialize(cx);

        // cx.background_executor()
        //     .spawn(async move {
        //         if let Ok(serialized_thread) = serialize_task.await {
        //             if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
        //                 telemetry::event!(
        //                     "Agent Thread Auto-Captured",
        //                     thread_id = thread_id.to_string(),
        //                     thread_data = thread_data,
        //                     auto_capture_reason = "tracked_user",
        //                     github_login = github_login
        //                 );

        //                 client.telemetry().flush_events().await;
        //             }
        //         }
        //     })
        //     .detach();
    }
1238
    /// Returns the token usage accumulated across this thread's requests.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
1242
    /// Returns the total token usage at the point just before `message_id`.
    ///
    /// NOTE(review): currently unimplemented — calling this panics via
    /// `todo!()`. The commented-out body below is the pre-refactor
    /// implementation (look up the message index and report the request usage
    /// recorded for the preceding message, against the model's max token
    /// count).
    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
        todo!()
        // let Some(model) = self.configured_model.as_ref() else {
        //     return TotalTokenUsage::default();
        // };

        // let max = model.model.max_token_count();

        // let index = self
        //     .messages
        //     .iter()
        //     .position(|msg| msg.id == message_id)
        //     .unwrap_or(0);

        // if index == 0 {
        //     return TotalTokenUsage { total: 0, max };
        // }

        // let token_usage = &self
        //     .request_token_usage
        //     .get(index - 1)
        //     .cloned()
        //     .unwrap_or_default();

        // TotalTokenUsage {
        //     total: token_usage.total_tokens(),
        //     max,
        // }
    }
1272
    /// Returns the thread's current total token usage for the configured
    /// model, or `None` if no model is configured.
    ///
    /// NOTE(review): currently unimplemented — calling this panics via
    /// `todo!()`. The commented-out body below is the pre-refactor
    /// implementation (prefer the recorded context-window-exceeded count for
    /// the same model, otherwise the usage at the last message).
    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
        todo!()
        // let model = self.configured_model.as_ref()?;

        // let max = model.model.max_token_count();

        // if let Some(exceeded_error) = &self.exceeded_window_error {
        //     if model.model.id() == exceeded_error.model_id {
        //         return Some(TotalTokenUsage {
        //             total: exceeded_error.token_count,
        //             max,
        //         });
        //     }
        // }

        // let total = self
        //     .token_usage_at_last_message()
        //     .unwrap_or_default()
        //     .total_tokens();

        // Some(TotalTokenUsage { total, max })
    }
1295
1296 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
1297 self.request_token_usage
1298 .get(self.messages.len().saturating_sub(1))
1299 .or_else(|| self.request_token_usage.last())
1300 .cloned()
1301 }
1302
1303 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
1304 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
1305 self.request_token_usage
1306 .resize(self.messages.len(), placeholder);
1307
1308 if let Some(last) = self.request_token_usage.last_mut() {
1309 *last = token_usage;
1310 }
1311 }
1312
1313 fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
1314 self.project.update(cx, |project, cx| {
1315 project.user_store().update(cx, |user_store, cx| {
1316 user_store.update_model_request_usage(
1317 ModelRequestUsage(RequestUsage {
1318 amount: amount as i32,
1319 limit,
1320 }),
1321 cx,
1322 )
1323 })
1324 });
1325 }
1326}
1327
/// Errors surfaced to the user while running a thread (delivered via
/// [`ThreadEvent::ShowError`]).
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider rejected the request because payment is required.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has reached its model-request limit.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a display header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1340
/// Events emitted by a `Thread` as completions stream and its messages change.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// A completion chunk was streamed in.
    StreamedCompletion,
    /// A text chunk was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text was streamed into the message with the given id.
    StreamedAssistantText(MessageId, String),
    /// Assistant thinking text was streamed into the message with the given id.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool use with its UI label and input payload.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use is missing (NOTE(review): exact trigger not visible here —
    /// presumably a referenced tool use without a result, or vice versa).
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse; the raw JSON is kept.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A thread summary was generated.
    SummaryGenerated,
    /// The thread summary changed.
    SummaryChanged,
    /// The thread's checkpoint changed.
    CheckpointChanged,
    /// A tool use requires user confirmation before running.
    ToolConfirmationNeeded,
    /// The tool-use limit for the request was reached.
    ToolUseLimitReached,
    /// Editing was cancelled.
    CancelEditing,
    /// The in-flight completion was cancelled.
    CompletionCanceled,
    /// The active agent profile changed.
    ProfileChanged,
    /// All retry attempts failed; carries a user-facing message.
    RetriesFailed {
        message: SharedString,
    },
}
1379
// Lets `Thread` emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
1381
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier for this pending completion.
    id: usize,
    // Queueing state of this completion.
    queue_state: QueueState,
    // Task driving the completion; retained so the work keeps running
    // (NOTE(review): gpui tasks are presumably cancelled on drop — confirm).
    _task: Task<()>,
}
1387
1388/// Resolves tool name conflicts by ensuring all tool names are unique.
1389///
1390/// When multiple tools have the same name, this function applies the following rules:
1391/// 1. Native tools always keep their original name
1392/// 2. Context server tools get prefixed with their server ID and an underscore
1393/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
1394/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
1395///
1396/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
1397fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
1398 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
1399 let mut tool_name = tool.name();
1400 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
1401 tool_name
1402 }
1403
1404 const MAX_TOOL_NAME_LENGTH: usize = 64;
1405
1406 let mut duplicated_tool_names = HashSet::default();
1407 let mut seen_tool_names = HashSet::default();
1408 for tool in tools {
1409 let tool_name = resolve_tool_name(tool);
1410 if seen_tool_names.contains(&tool_name) {
1411 debug_assert!(
1412 tool.source() != assistant_tool::ToolSource::Native,
1413 "There are two built-in tools with the same name: {}",
1414 tool_name
1415 );
1416 duplicated_tool_names.insert(tool_name);
1417 } else {
1418 seen_tool_names.insert(tool_name);
1419 }
1420 }
1421
1422 if duplicated_tool_names.is_empty() {
1423 return tools
1424 .into_iter()
1425 .map(|tool| (resolve_tool_name(tool), tool.clone()))
1426 .collect();
1427 }
1428
1429 tools
1430 .into_iter()
1431 .filter_map(|tool| {
1432 let mut tool_name = resolve_tool_name(tool);
1433 if !duplicated_tool_names.contains(&tool_name) {
1434 return Some((tool_name, tool.clone()));
1435 }
1436 match tool.source() {
1437 assistant_tool::ToolSource::Native => {
1438 // Built-in tools always keep their original name
1439 Some((tool_name, tool.clone()))
1440 }
1441 assistant_tool::ToolSource::ContextServer { id } => {
1442 // Context server tools are prefixed with the context server ID, and truncated if necessary
1443 tool_name.insert(0, '_');
1444 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
1445 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
1446 let mut id = id.to_string();
1447 id.truncate(len);
1448 tool_name.insert_str(0, &id);
1449 } else {
1450 tool_name.insert_str(0, &id);
1451 }
1452
1453 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
1454
1455 if seen_tool_names.contains(&tool_name) {
1456 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
1457 None
1458 } else {
1459 Some((tool_name, tool.clone()))
1460 }
1461 }
1462 }
1463 })
1464 .collect()
1465}
1466
1467// #[cfg(test)]
1468// mod tests {
1469// use super::*;
1470// use crate::{
1471// context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
1472// };
1473
1474// // Test-specific constants
1475// const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
1476// use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
1477// use assistant_tool::ToolRegistry;
1478// use futures::StreamExt;
1479// use futures::future::BoxFuture;
1480// use futures::stream::BoxStream;
1481// use gpui::TestAppContext;
1482// use icons::IconName;
1483// use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1484// use language_model::{
1485// LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
1486// LanguageModelProviderName, LanguageModelToolChoice,
1487// };
1488// use parking_lot::Mutex;
1489// use project::{FakeFs, Project};
1490// use prompt_store::PromptBuilder;
1491// use serde_json::json;
1492// use settings::{Settings, SettingsStore};
1493// use std::sync::Arc;
1494// use std::time::Duration;
1495// use theme::ThemeSettings;
1496// use util::path;
1497// use workspace::Workspace;
1498
1499// #[gpui::test]
1500// async fn test_message_with_context(cx: &mut TestAppContext) {
1501// init_test_settings(cx);
1502
1503// let project = create_test_project(
1504// cx,
1505// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1506// )
1507// .await;
1508
1509// let (_workspace, _thread_store, thread, context_store, model) =
1510// setup_test_environment(cx, project.clone()).await;
1511
1512// add_file_to_context(&project, &context_store, "test/code.rs", cx)
1513// .await
1514// .unwrap();
1515
1516// let context =
1517// context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
1518// let loaded_context = cx
1519// .update(|cx| load_context(vec![context], &project, &None, cx))
1520// .await;
1521
1522// // Insert user message with context
1523// let message_id = thread.update(cx, |thread, cx| {
1524// thread.insert_user_message(
1525// "Please explain this code",
1526// loaded_context,
1527// None,
1528// Vec::new(),
1529// cx,
1530// )
1531// });
1532
1533// // Check content and context in message object
1534// let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
1535
1536// // Use different path format strings based on platform for the test
1537// #[cfg(windows)]
1538// let path_part = r"test\code.rs";
1539// #[cfg(not(windows))]
1540// let path_part = "test/code.rs";
1541
1542// let expected_context = format!(
1543// r#"
1544// <context>
1545// The following items were attached by the user. They are up-to-date and don't need to be re-read.
1546
1547// <files>
1548// ```rs {path_part}
1549// fn main() {{
1550// println!("Hello, world!");
1551// }}
1552// ```
1553// </files>
1554// </context>
1555// "#
1556// );
1557
1558// assert_eq!(message.role, Role::User);
1559// assert_eq!(message.segments.len(), 1);
1560// assert_eq!(
1561// message.segments[0],
1562// MessageSegment::Text("Please explain this code".to_string())
1563// );
1564// assert_eq!(message.loaded_context.text, expected_context);
1565
1566// // Check message in request
1567// let request = thread.update(cx, |thread, cx| {
1568// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1569// });
1570
1571// assert_eq!(request.messages.len(), 2);
1572// let expected_full_message = format!("{}Please explain this code", expected_context);
1573// assert_eq!(request.messages[1].string_contents(), expected_full_message);
1574// }
1575
1576// #[gpui::test]
1577// async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
1578// init_test_settings(cx);
1579
1580// let project = create_test_project(
1581// cx,
1582// json!({
1583// "file1.rs": "fn function1() {}\n",
1584// "file2.rs": "fn function2() {}\n",
1585// "file3.rs": "fn function3() {}\n",
1586// "file4.rs": "fn function4() {}\n",
1587// }),
1588// )
1589// .await;
1590
1591// let (_, _thread_store, thread, context_store, model) =
1592// setup_test_environment(cx, project.clone()).await;
1593
1594// // First message with context 1
1595// add_file_to_context(&project, &context_store, "test/file1.rs", cx)
1596// .await
1597// .unwrap();
1598// let new_contexts = context_store.update(cx, |store, cx| {
1599// store.new_context_for_thread(thread.read(cx), None)
1600// });
1601// assert_eq!(new_contexts.len(), 1);
1602// let loaded_context = cx
1603// .update(|cx| load_context(new_contexts, &project, &None, cx))
1604// .await;
1605// let message1_id = thread.update(cx, |thread, cx| {
1606// thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
1607// });
1608
1609// // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
1610// add_file_to_context(&project, &context_store, "test/file2.rs", cx)
1611// .await
1612// .unwrap();
1613// let new_contexts = context_store.update(cx, |store, cx| {
1614// store.new_context_for_thread(thread.read(cx), None)
1615// });
1616// assert_eq!(new_contexts.len(), 1);
1617// let loaded_context = cx
1618// .update(|cx| load_context(new_contexts, &project, &None, cx))
1619// .await;
1620// let message2_id = thread.update(cx, |thread, cx| {
1621// thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
1622// });
1623
1624// // Third message with all three contexts (contexts 1 and 2 should be skipped)
1625// //
1626// add_file_to_context(&project, &context_store, "test/file3.rs", cx)
1627// .await
1628// .unwrap();
1629// let new_contexts = context_store.update(cx, |store, cx| {
1630// store.new_context_for_thread(thread.read(cx), None)
1631// });
1632// assert_eq!(new_contexts.len(), 1);
1633// let loaded_context = cx
1634// .update(|cx| load_context(new_contexts, &project, &None, cx))
1635// .await;
1636// let message3_id = thread.update(cx, |thread, cx| {
1637// thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
1638// });
1639
1640// // Check what contexts are included in each message
1641// let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
1642// (
1643// thread.message(message1_id).unwrap().clone(),
1644// thread.message(message2_id).unwrap().clone(),
1645// thread.message(message3_id).unwrap().clone(),
1646// )
1647// });
1648
1649// // First message should include context 1
1650// assert!(message1.loaded_context.text.contains("file1.rs"));
1651
1652// // Second message should include only context 2 (not 1)
1653// assert!(!message2.loaded_context.text.contains("file1.rs"));
1654// assert!(message2.loaded_context.text.contains("file2.rs"));
1655
1656// // Third message should include only context 3 (not 1 or 2)
1657// assert!(!message3.loaded_context.text.contains("file1.rs"));
1658// assert!(!message3.loaded_context.text.contains("file2.rs"));
1659// assert!(message3.loaded_context.text.contains("file3.rs"));
1660
1661// // Check entire request to make sure all contexts are properly included
1662// let request = thread.update(cx, |thread, cx| {
1663// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1664// });
1665
1666// // The request should contain all 3 messages
1667// assert_eq!(request.messages.len(), 4);
1668
1669// // Check that the contexts are properly formatted in each message
1670// assert!(request.messages[1].string_contents().contains("file1.rs"));
1671// assert!(!request.messages[1].string_contents().contains("file2.rs"));
1672// assert!(!request.messages[1].string_contents().contains("file3.rs"));
1673
1674// assert!(!request.messages[2].string_contents().contains("file1.rs"));
1675// assert!(request.messages[2].string_contents().contains("file2.rs"));
1676// assert!(!request.messages[2].string_contents().contains("file3.rs"));
1677
1678// assert!(!request.messages[3].string_contents().contains("file1.rs"));
1679// assert!(!request.messages[3].string_contents().contains("file2.rs"));
1680// assert!(request.messages[3].string_contents().contains("file3.rs"));
1681
1682// add_file_to_context(&project, &context_store, "test/file4.rs", cx)
1683// .await
1684// .unwrap();
1685// let new_contexts = context_store.update(cx, |store, cx| {
1686// store.new_context_for_thread(thread.read(cx), Some(message2_id))
1687// });
1688// assert_eq!(new_contexts.len(), 3);
1689// let loaded_context = cx
1690// .update(|cx| load_context(new_contexts, &project, &None, cx))
1691// .await
1692// .loaded_context;
1693
1694// assert!(!loaded_context.text.contains("file1.rs"));
1695// assert!(loaded_context.text.contains("file2.rs"));
1696// assert!(loaded_context.text.contains("file3.rs"));
1697// assert!(loaded_context.text.contains("file4.rs"));
1698
1699// let new_contexts = context_store.update(cx, |store, cx| {
1700// // Remove file4.rs
1701// store.remove_context(&loaded_context.contexts[2].handle(), cx);
1702// store.new_context_for_thread(thread.read(cx), Some(message2_id))
1703// });
1704// assert_eq!(new_contexts.len(), 2);
1705// let loaded_context = cx
1706// .update(|cx| load_context(new_contexts, &project, &None, cx))
1707// .await
1708// .loaded_context;
1709
1710// assert!(!loaded_context.text.contains("file1.rs"));
1711// assert!(loaded_context.text.contains("file2.rs"));
1712// assert!(loaded_context.text.contains("file3.rs"));
1713// assert!(!loaded_context.text.contains("file4.rs"));
1714
1715// let new_contexts = context_store.update(cx, |store, cx| {
1716// // Remove file3.rs
1717// store.remove_context(&loaded_context.contexts[1].handle(), cx);
1718// store.new_context_for_thread(thread.read(cx), Some(message2_id))
1719// });
1720// assert_eq!(new_contexts.len(), 1);
1721// let loaded_context = cx
1722// .update(|cx| load_context(new_contexts, &project, &None, cx))
1723// .await
1724// .loaded_context;
1725
1726// assert!(!loaded_context.text.contains("file1.rs"));
1727// assert!(loaded_context.text.contains("file2.rs"));
1728// assert!(!loaded_context.text.contains("file3.rs"));
1729// assert!(!loaded_context.text.contains("file4.rs"));
1730// }
1731
1732// #[gpui::test]
1733// async fn test_message_without_files(cx: &mut TestAppContext) {
1734// init_test_settings(cx);
1735
1736// let project = create_test_project(
1737// cx,
1738// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1739// )
1740// .await;
1741
1742// let (_, _thread_store, thread, _context_store, model) =
1743// setup_test_environment(cx, project.clone()).await;
1744
1745// // Insert user message without any context (empty context vector)
1746// let message_id = thread.update(cx, |thread, cx| {
1747// thread.insert_user_message(
1748// "What is the best way to learn Rust?",
1749// ContextLoadResult::default(),
1750// None,
1751// Vec::new(),
1752// cx,
1753// )
1754// });
1755
1756// // Check content and context in message object
1757// let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
1758
1759// // Context should be empty when no files are included
1760// assert_eq!(message.role, Role::User);
1761// assert_eq!(message.segments.len(), 1);
1762// assert_eq!(
1763// message.segments[0],
1764// MessageSegment::Text("What is the best way to learn Rust?".to_string())
1765// );
1766// assert_eq!(message.loaded_context.text, "");
1767
1768// // Check message in request
1769// let request = thread.update(cx, |thread, cx| {
1770// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1771// });
1772
1773// assert_eq!(request.messages.len(), 2);
1774// assert_eq!(
1775// request.messages[1].string_contents(),
1776// "What is the best way to learn Rust?"
1777// );
1778
1779// // Add second message, also without context
1780// let message2_id = thread.update(cx, |thread, cx| {
1781// thread.insert_user_message(
1782// "Are there any good books?",
1783// ContextLoadResult::default(),
1784// None,
1785// Vec::new(),
1786// cx,
1787// )
1788// });
1789
1790// let message2 =
1791// thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
1792// assert_eq!(message2.loaded_context.text, "");
1793
1794// // Check that both messages appear in the request
1795// let request = thread.update(cx, |thread, cx| {
1796// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1797// });
1798
1799// assert_eq!(request.messages.len(), 3);
1800// assert_eq!(
1801// request.messages[1].string_contents(),
1802// "What is the best way to learn Rust?"
1803// );
1804// assert_eq!(
1805// request.messages[2].string_contents(),
1806// "Are there any good books?"
1807// );
1808// }
1809
1810// #[gpui::test]
1811// async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
1812// init_test_settings(cx);
1813
1814// let project = create_test_project(
1815// cx,
1816// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1817// )
1818// .await;
1819
1820// let (_workspace, thread_store, thread, _context_store, _model) =
1821// setup_test_environment(cx, project.clone()).await;
1822
1823// // Check that we are starting with the default profile
1824// let profile = cx.read(|cx| thread.read(cx).profile.clone());
1825// let tool_set = cx.read(|cx| thread_store.read(cx).tools());
1826// assert_eq!(
1827// profile,
1828// AgentProfile::new(AgentProfileId::default(), tool_set)
1829// );
1830// }
1831
1832// #[gpui::test]
1833// async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
1834// init_test_settings(cx);
1835
1836// let project = create_test_project(
1837// cx,
1838// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1839// )
1840// .await;
1841
1842// let (_workspace, thread_store, thread, _context_store, _model) =
1843// setup_test_environment(cx, project.clone()).await;
1844
1845// // Profile gets serialized with default values
1846// let serialized = thread
1847// .update(cx, |thread, cx| thread.serialize(cx))
1848// .await
1849// .unwrap();
1850
1851// assert_eq!(serialized.profile, Some(AgentProfileId::default()));
1852
1853// let deserialized = cx.update(|cx| {
1854// thread.update(cx, |thread, cx| {
1855// Thread::deserialize(
1856// thread.id.clone(),
1857// serialized,
1858// thread.project.clone(),
1859// thread.tools.clone(),
1860// thread.prompt_builder.clone(),
1861// thread.project_context.clone(),
1862// None,
1863// cx,
1864// )
1865// })
1866// });
1867// let tool_set = cx.read(|cx| thread_store.read(cx).tools());
1868
1869// assert_eq!(
1870// deserialized.profile,
1871// AgentProfile::new(AgentProfileId::default(), tool_set)
1872// );
1873// }
1874
1875// #[gpui::test]
1876// async fn test_temperature_setting(cx: &mut TestAppContext) {
1877// init_test_settings(cx);
1878
1879// let project = create_test_project(
1880// cx,
1881// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1882// )
1883// .await;
1884
1885// let (_workspace, _thread_store, thread, _context_store, model) =
1886// setup_test_environment(cx, project.clone()).await;
1887
1888// // Both model and provider
1889// cx.update(|cx| {
1890// AgentSettings::override_global(
1891// AgentSettings {
1892// model_parameters: vec![LanguageModelParameters {
1893// provider: Some(model.provider_id().0.to_string().into()),
1894// model: Some(model.id().0.clone()),
1895// temperature: Some(0.66),
1896// }],
1897// ..AgentSettings::get_global(cx).clone()
1898// },
1899// cx,
1900// );
1901// });
1902
1903// let request = thread.update(cx, |thread, cx| {
1904// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1905// });
1906// assert_eq!(request.temperature, Some(0.66));
1907
1908// // Only model
1909// cx.update(|cx| {
1910// AgentSettings::override_global(
1911// AgentSettings {
1912// model_parameters: vec![LanguageModelParameters {
1913// provider: None,
1914// model: Some(model.id().0.clone()),
1915// temperature: Some(0.66),
1916// }],
1917// ..AgentSettings::get_global(cx).clone()
1918// },
1919// cx,
1920// );
1921// });
1922
1923// let request = thread.update(cx, |thread, cx| {
1924// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1925// });
1926// assert_eq!(request.temperature, Some(0.66));
1927
1928// // Only provider
1929// cx.update(|cx| {
1930// AgentSettings::override_global(
1931// AgentSettings {
1932// model_parameters: vec![LanguageModelParameters {
1933// provider: Some(model.provider_id().0.to_string().into()),
1934// model: None,
1935// temperature: Some(0.66),
1936// }],
1937// ..AgentSettings::get_global(cx).clone()
1938// },
1939// cx,
1940// );
1941// });
1942
1943// let request = thread.update(cx, |thread, cx| {
1944// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1945// });
1946// assert_eq!(request.temperature, Some(0.66));
1947
1948// // Same model name, different provider
1949// cx.update(|cx| {
1950// AgentSettings::override_global(
1951// AgentSettings {
1952// model_parameters: vec![LanguageModelParameters {
1953// provider: Some("anthropic".into()),
1954// model: Some(model.id().0.clone()),
1955// temperature: Some(0.66),
1956// }],
1957// ..AgentSettings::get_global(cx).clone()
1958// },
1959// cx,
1960// );
1961// });
1962
1963// let request = thread.update(cx, |thread, cx| {
1964// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1965// });
1966// assert_eq!(request.temperature, None);
1967// }
1968
1969// #[gpui::test]
1970// async fn test_thread_summary(cx: &mut TestAppContext) {
1971// init_test_settings(cx);
1972
1973// let project = create_test_project(cx, json!({})).await;
1974
1975// let (_, _thread_store, thread, _context_store, model) =
1976// setup_test_environment(cx, project.clone()).await;
1977
1978// // Initial state should be pending
1979// thread.read_with(cx, |thread, _| {
1980// assert!(matches!(thread.summary(), ThreadSummary::Pending));
1981// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
1982// });
1983
1984// // Manually setting the summary should not be allowed in this state
1985// thread.update(cx, |thread, cx| {
1986// thread.set_summary("This should not work", cx);
1987// });
1988
1989// thread.read_with(cx, |thread, _| {
1990// assert!(matches!(thread.summary(), ThreadSummary::Pending));
1991// });
1992
1993// // Send a message
1994// thread.update(cx, |thread, cx| {
1995// thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
1996// thread.send_to_model(
1997// model.clone(),
1998// CompletionIntent::ThreadSummarization,
1999// None,
2000// cx,
2001// );
2002// });
2003
2004// let fake_model = model.as_fake();
2005// simulate_successful_response(&fake_model, cx);
2006
2007// // Should start generating summary when there are >= 2 messages
2008// thread.read_with(cx, |thread, _| {
2009// assert_eq!(*thread.summary(), ThreadSummary::Generating);
2010// });
2011
2012// // Should not be able to set the summary while generating
2013// thread.update(cx, |thread, cx| {
2014// thread.set_summary("This should not work either", cx);
2015// });
2016
2017// thread.read_with(cx, |thread, _| {
2018// assert!(matches!(thread.summary(), ThreadSummary::Generating));
2019// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
2020// });
2021
2022// cx.run_until_parked();
2023// fake_model.stream_last_completion_response("Brief");
2024// fake_model.stream_last_completion_response(" Introduction");
2025// fake_model.end_last_completion_stream();
2026// cx.run_until_parked();
2027
2028// // Summary should be set
2029// thread.read_with(cx, |thread, _| {
2030// assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
2031// assert_eq!(thread.summary().or_default(), "Brief Introduction");
2032// });
2033
2034// // Now we should be able to set a summary
2035// thread.update(cx, |thread, cx| {
2036// thread.set_summary("Brief Intro", cx);
2037// });
2038
2039// thread.read_with(cx, |thread, _| {
2040// assert_eq!(thread.summary().or_default(), "Brief Intro");
2041// });
2042
2043// // Test setting an empty summary (should default to DEFAULT)
2044// thread.update(cx, |thread, cx| {
2045// thread.set_summary("", cx);
2046// });
2047
2048// thread.read_with(cx, |thread, _| {
2049// assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
2050// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
2051// });
2052// }
2053
2054// #[gpui::test]
2055// async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
2056// init_test_settings(cx);
2057
2058// let project = create_test_project(cx, json!({})).await;
2059
2060// let (_, _thread_store, thread, _context_store, model) =
2061// setup_test_environment(cx, project.clone()).await;
2062
2063// test_summarize_error(&model, &thread, cx);
2064
2065// // Now we should be able to set a summary
2066// thread.update(cx, |thread, cx| {
2067// thread.set_summary("Brief Intro", cx);
2068// });
2069
2070// thread.read_with(cx, |thread, _| {
2071// assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
2072// assert_eq!(thread.summary().or_default(), "Brief Intro");
2073// });
2074// }
2075
2076// #[gpui::test]
2077// async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
2078// init_test_settings(cx);
2079
2080// let project = create_test_project(cx, json!({})).await;
2081
2082// let (_, _thread_store, thread, _context_store, model) =
2083// setup_test_environment(cx, project.clone()).await;
2084
2085// test_summarize_error(&model, &thread, cx);
2086
2087// // Sending another message should not trigger another summarize request
2088// thread.update(cx, |thread, cx| {
2089// thread.insert_user_message(
2090// "How are you?",
2091// ContextLoadResult::default(),
2092// None,
2093// vec![],
2094// cx,
2095// );
2096// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2097// });
2098
2099// let fake_model = model.as_fake();
2100// simulate_successful_response(&fake_model, cx);
2101
2102// thread.read_with(cx, |thread, _| {
2103// // State is still Error, not Generating
2104// assert!(matches!(thread.summary(), ThreadSummary::Error));
2105// });
2106
2107// // But the summarize request can be invoked manually
2108// thread.update(cx, |thread, cx| {
2109// thread.summarize(cx);
2110// });
2111
2112// thread.read_with(cx, |thread, _| {
2113// assert!(matches!(thread.summary(), ThreadSummary::Generating));
2114// });
2115
2116// cx.run_until_parked();
2117// fake_model.stream_last_completion_response("A successful summary");
2118// fake_model.end_last_completion_stream();
2119// cx.run_until_parked();
2120
2121// thread.read_with(cx, |thread, _| {
2122// assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
2123// assert_eq!(thread.summary().or_default(), "A successful summary");
2124// });
2125// }
2126
2127// #[gpui::test]
2128// fn test_resolve_tool_name_conflicts() {
2129// use assistant_tool::{Tool, ToolSource};
2130
2131// assert_resolve_tool_name_conflicts(
2132// vec![
2133// TestTool::new("tool1", ToolSource::Native),
2134// TestTool::new("tool2", ToolSource::Native),
2135// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
2136// ],
2137// vec!["tool1", "tool2", "tool3"],
2138// );
2139
2140// assert_resolve_tool_name_conflicts(
2141// vec![
2142// TestTool::new("tool1", ToolSource::Native),
2143// TestTool::new("tool2", ToolSource::Native),
2144// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
2145// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
2146// ],
2147// vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
2148// );
2149
2150// assert_resolve_tool_name_conflicts(
2151// vec![
2152// TestTool::new("tool1", ToolSource::Native),
2153// TestTool::new("tool2", ToolSource::Native),
2154// TestTool::new("tool3", ToolSource::Native),
2155// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
2156// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
2157// ],
2158// vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
2159// );
2160
// // Test that a tool with a very long name is always truncated
2162// assert_resolve_tool_name_conflicts(
2163// vec![TestTool::new(
2164// "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
2165// ToolSource::Native,
2166// )],
2167// vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
2168// );
2169
// // Test deduplication of tools with very long names; in this case the MCP server name should be truncated
2171// assert_resolve_tool_name_conflicts(
2172// vec![
2173// TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
2174// TestTool::new(
2175// "tool-with-very-very-very-long-name",
2176// ToolSource::ContextServer {
2177// id: "mcp-with-very-very-very-long-name".into(),
2178// },
2179// ),
2180// ],
2181// vec![
2182// "tool-with-very-very-very-long-name",
2183// "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
2184// ],
2185// );
2186
2187// fn assert_resolve_tool_name_conflicts(
2188// tools: Vec<TestTool>,
2189// expected: Vec<impl Into<String>>,
2190// ) {
2191// let tools: Vec<Arc<dyn Tool>> = tools
2192// .into_iter()
2193// .map(|t| Arc::new(t) as Arc<dyn Tool>)
2194// .collect();
2195// let tools = resolve_tool_name_conflicts(&tools);
2196// assert_eq!(tools.len(), expected.len());
2197// for (i, expected_name) in expected.into_iter().enumerate() {
2198// let expected_name = expected_name.into();
2199// let actual_name = &tools[i].0;
2200// assert_eq!(
2201// actual_name, &expected_name,
2202// "Expected '{}' got '{}' at index {}",
2203// expected_name, actual_name, i
2204// );
2205// }
2206// }
2207
2208// struct TestTool {
2209// name: String,
2210// source: ToolSource,
2211// }
2212
2213// impl TestTool {
2214// fn new(name: impl Into<String>, source: ToolSource) -> Self {
2215// Self {
2216// name: name.into(),
2217// source,
2218// }
2219// }
2220// }
2221
2222// impl Tool for TestTool {
2223// fn name(&self) -> String {
2224// self.name.clone()
2225// }
2226
2227// fn icon(&self) -> IconName {
2228// IconName::Ai
2229// }
2230
2231// fn may_perform_edits(&self) -> bool {
2232// false
2233// }
2234
2235// fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
2236// true
2237// }
2238
2239// fn source(&self) -> ToolSource {
2240// self.source.clone()
2241// }
2242
2243// fn description(&self) -> String {
2244// "Test tool".to_string()
2245// }
2246
2247// fn ui_text(&self, _input: &serde_json::Value) -> String {
2248// "Test tool".to_string()
2249// }
2250
2251// fn run(
2252// self: Arc<Self>,
2253// _input: serde_json::Value,
2254// _request: Arc<LanguageModelRequest>,
2255// _project: Entity<Project>,
2256// _action_log: Entity<ActionLog>,
2257// _model: Arc<dyn LanguageModel>,
2258// _window: Option<AnyWindowHandle>,
2259// _cx: &mut App,
2260// ) -> assistant_tool::ToolResult {
2261// assistant_tool::ToolResult {
2262// output: Task::ready(Err(anyhow::anyhow!("No content"))),
2263// card: None,
2264// }
2265// }
2266// }
2267// }
2268
2269// // Helper to create a model that returns errors
2270// enum TestError {
2271// Overloaded,
2272// InternalServerError,
2273// }
2274
2275// struct ErrorInjector {
2276// inner: Arc<FakeLanguageModel>,
2277// error_type: TestError,
2278// }
2279
2280// impl ErrorInjector {
2281// fn new(error_type: TestError) -> Self {
2282// Self {
2283// inner: Arc::new(FakeLanguageModel::default()),
2284// error_type,
2285// }
2286// }
2287// }
2288
2289// impl LanguageModel for ErrorInjector {
2290// fn id(&self) -> LanguageModelId {
2291// self.inner.id()
2292// }
2293
2294// fn name(&self) -> LanguageModelName {
2295// self.inner.name()
2296// }
2297
2298// fn provider_id(&self) -> LanguageModelProviderId {
2299// self.inner.provider_id()
2300// }
2301
2302// fn provider_name(&self) -> LanguageModelProviderName {
2303// self.inner.provider_name()
2304// }
2305
2306// fn supports_tools(&self) -> bool {
2307// self.inner.supports_tools()
2308// }
2309
2310// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
2311// self.inner.supports_tool_choice(choice)
2312// }
2313
2314// fn supports_images(&self) -> bool {
2315// self.inner.supports_images()
2316// }
2317
2318// fn telemetry_id(&self) -> String {
2319// self.inner.telemetry_id()
2320// }
2321
2322// fn max_token_count(&self) -> u64 {
2323// self.inner.max_token_count()
2324// }
2325
2326// fn count_tokens(
2327// &self,
2328// request: LanguageModelRequest,
2329// cx: &App,
2330// ) -> BoxFuture<'static, Result<u64>> {
2331// self.inner.count_tokens(request, cx)
2332// }
2333
2334// fn stream_completion(
2335// &self,
2336// _request: LanguageModelRequest,
2337// _cx: &AsyncApp,
2338// ) -> BoxFuture<
2339// 'static,
2340// Result<
2341// BoxStream<
2342// 'static,
2343// Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
2344// >,
2345// LanguageModelCompletionError,
2346// >,
2347// > {
2348// let error = match self.error_type {
2349// TestError::Overloaded => LanguageModelCompletionError::Overloaded,
2350// TestError::InternalServerError => {
2351// LanguageModelCompletionError::ApiInternalServerError
2352// }
2353// };
2354// async move {
2355// let stream = futures::stream::once(async move { Err(error) });
2356// Ok(stream.boxed())
2357// }
2358// .boxed()
2359// }
2360
2361// fn as_fake(&self) -> &FakeLanguageModel {
2362// &self.inner
2363// }
2364// }
2365
2366// #[gpui::test]
2367// async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
2368// init_test_settings(cx);
2369
2370// let project = create_test_project(cx, json!({})).await;
2371// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2372
2373// // Create model that returns overloaded error
2374// let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
2375
2376// // Insert a user message
2377// thread.update(cx, |thread, cx| {
2378// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2379// });
2380
2381// // Start completion
2382// thread.update(cx, |thread, cx| {
2383// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2384// });
2385
2386// cx.run_until_parked();
2387
2388// thread.read_with(cx, |thread, _| {
2389// assert!(thread.retry_state.is_some(), "Should have retry state");
2390// let retry_state = thread.retry_state.as_ref().unwrap();
2391// assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
2392// assert_eq!(
2393// retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
2394// "Should have default max attempts"
2395// );
2396// });
2397
2398// // Check that a retry message was added
2399// thread.read_with(cx, |thread, _| {
2400// let mut messages = thread.messages();
2401// assert!(
2402// messages.any(|msg| {
2403// msg.role == Role::System
2404// && msg.ui_only
2405// && msg.segments.iter().any(|seg| {
2406// if let MessageSegment::Text(text) = seg {
2407// text.contains("overloaded")
2408// && text
2409// .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
2410// } else {
2411// false
2412// }
2413// })
2414// }),
2415// "Should have added a system retry message"
2416// );
2417// });
2418
2419// let retry_count = thread.update(cx, |thread, _| {
2420// thread
2421// .messages
2422// .iter()
2423// .filter(|m| {
2424// m.ui_only
2425// && m.segments.iter().any(|s| {
2426// if let MessageSegment::Text(text) = s {
2427// text.contains("Retrying") && text.contains("seconds")
2428// } else {
2429// false
2430// }
2431// })
2432// })
2433// .count()
2434// });
2435
2436// assert_eq!(retry_count, 1, "Should have one retry message");
2437// }
2438
2439// #[gpui::test]
2440// async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
2441// init_test_settings(cx);
2442
2443// let project = create_test_project(cx, json!({})).await;
2444// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2445
2446// // Create model that returns internal server error
2447// let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
2448
2449// // Insert a user message
2450// thread.update(cx, |thread, cx| {
2451// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2452// });
2453
2454// // Start completion
2455// thread.update(cx, |thread, cx| {
2456// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2457// });
2458
2459// cx.run_until_parked();
2460
2461// // Check retry state on thread
2462// thread.read_with(cx, |thread, _| {
2463// assert!(thread.retry_state.is_some(), "Should have retry state");
2464// let retry_state = thread.retry_state.as_ref().unwrap();
2465// assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
2466// assert_eq!(
2467// retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
2468// "Should have correct max attempts"
2469// );
2470// });
2471
2472// // Check that a retry message was added with provider name
2473// thread.read_with(cx, |thread, _| {
2474// let mut messages = thread.messages();
2475// assert!(
2476// messages.any(|msg| {
2477// msg.role == Role::System
2478// && msg.ui_only
2479// && msg.segments.iter().any(|seg| {
2480// if let MessageSegment::Text(text) = seg {
2481// text.contains("internal")
2482// && text.contains("Fake")
2483// && text
2484// .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
2485// } else {
2486// false
2487// }
2488// })
2489// }),
2490// "Should have added a system retry message with provider name"
2491// );
2492// });
2493
2494// // Count retry messages
2495// let retry_count = thread.update(cx, |thread, _| {
2496// thread
2497// .messages
2498// .iter()
2499// .filter(|m| {
2500// m.ui_only
2501// && m.segments.iter().any(|s| {
2502// if let MessageSegment::Text(text) = s {
2503// text.contains("Retrying") && text.contains("seconds")
2504// } else {
2505// false
2506// }
2507// })
2508// })
2509// .count()
2510// });
2511
2512// assert_eq!(retry_count, 1, "Should have one retry message");
2513// }
2514
2515// #[gpui::test]
2516// async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
2517// init_test_settings(cx);
2518
2519// let project = create_test_project(cx, json!({})).await;
2520// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2521
2522// // Create model that returns overloaded error
2523// let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
2524
2525// // Insert a user message
2526// thread.update(cx, |thread, cx| {
2527// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2528// });
2529
// // Track how many completion requests are made (one NewRequest event per attempt)
2532// let completion_count = Arc::new(Mutex::new(0));
2533// let completion_count_clone = completion_count.clone();
2534
2535// let _subscription = thread.update(cx, |_, cx| {
2536// cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
2537// if let ThreadEvent::NewRequest = event {
2538// *completion_count_clone.lock() += 1;
2539// }
2540// })
2541// });
2542
2543// // First attempt
2544// thread.update(cx, |thread, cx| {
2545// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2546// });
2547// cx.run_until_parked();
2548
2549// // Should have scheduled first retry - count retry messages
2550// let retry_count = thread.update(cx, |thread, _| {
2551// thread
2552// .messages
2553// .iter()
2554// .filter(|m| {
2555// m.ui_only
2556// && m.segments.iter().any(|s| {
2557// if let MessageSegment::Text(text) = s {
2558// text.contains("Retrying") && text.contains("seconds")
2559// } else {
2560// false
2561// }
2562// })
2563// })
2564// .count()
2565// });
2566// assert_eq!(retry_count, 1, "Should have scheduled first retry");
2567
2568// // Check retry state
2569// thread.read_with(cx, |thread, _| {
2570// assert!(thread.retry_state.is_some(), "Should have retry state");
2571// let retry_state = thread.retry_state.as_ref().unwrap();
2572// assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
2573// });
2574
2575// // Advance clock for first retry
2576// cx.executor()
2577// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
2578// cx.run_until_parked();
2579
2580// // Should have scheduled second retry - count retry messages
2581// let retry_count = thread.update(cx, |thread, _| {
2582// thread
2583// .messages
2584// .iter()
2585// .filter(|m| {
2586// m.ui_only
2587// && m.segments.iter().any(|s| {
2588// if let MessageSegment::Text(text) = s {
2589// text.contains("Retrying") && text.contains("seconds")
2590// } else {
2591// false
2592// }
2593// })
2594// })
2595// .count()
2596// });
2597// assert_eq!(retry_count, 2, "Should have scheduled second retry");
2598
2599// // Check retry state updated
2600// thread.read_with(cx, |thread, _| {
2601// assert!(thread.retry_state.is_some(), "Should have retry state");
2602// let retry_state = thread.retry_state.as_ref().unwrap();
2603// assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
2604// assert_eq!(
2605// retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
2606// "Should have correct max attempts"
2607// );
2608// });
2609
2610// // Advance clock for second retry (exponential backoff)
2611// cx.executor()
2612// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
2613// cx.run_until_parked();
2614
2615// // Should have scheduled third retry
2616// // Count all retry messages now
2617// let retry_count = thread.update(cx, |thread, _| {
2618// thread
2619// .messages
2620// .iter()
2621// .filter(|m| {
2622// m.ui_only
2623// && m.segments.iter().any(|s| {
2624// if let MessageSegment::Text(text) = s {
2625// text.contains("Retrying") && text.contains("seconds")
2626// } else {
2627// false
2628// }
2629// })
2630// })
2631// .count()
2632// });
2633// assert_eq!(
2634// retry_count, MAX_RETRY_ATTEMPTS as usize,
2635// "Should have scheduled third retry"
2636// );
2637
2638// // Check retry state updated
2639// thread.read_with(cx, |thread, _| {
2640// assert!(thread.retry_state.is_some(), "Should have retry state");
2641// let retry_state = thread.retry_state.as_ref().unwrap();
2642// assert_eq!(
2643// retry_state.attempt, MAX_RETRY_ATTEMPTS,
2644// "Should be at max retry attempt"
2645// );
2646// assert_eq!(
2647// retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
2648// "Should have correct max attempts"
2649// );
2650// });
2651
2652// // Advance clock for third retry (exponential backoff)
2653// cx.executor()
2654// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
2655// cx.run_until_parked();
2656
// // No more retries should be scheduled after the clock was advanced.
2658// let retry_count = thread.update(cx, |thread, _| {
2659// thread
2660// .messages
2661// .iter()
2662// .filter(|m| {
2663// m.ui_only
2664// && m.segments.iter().any(|s| {
2665// if let MessageSegment::Text(text) = s {
2666// text.contains("Retrying") && text.contains("seconds")
2667// } else {
2668// false
2669// }
2670// })
2671// })
2672// .count()
2673// });
2674// assert_eq!(
2675// retry_count, MAX_RETRY_ATTEMPTS as usize,
2676// "Should not exceed max retries"
2677// );
2678
2679// // Final completion count should be initial + max retries
2680// assert_eq!(
2681// *completion_count.lock(),
2682// (MAX_RETRY_ATTEMPTS + 1) as usize,
2683// "Should have made initial + max retry attempts"
2684// );
2685// }
2686
2687// #[gpui::test]
2688// async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
2689// init_test_settings(cx);
2690
2691// let project = create_test_project(cx, json!({})).await;
2692// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2693
2694// // Create model that returns overloaded error
2695// let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
2696
2697// // Insert a user message
2698// thread.update(cx, |thread, cx| {
2699// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2700// });
2701
2702// // Track events
2703// let retries_failed = Arc::new(Mutex::new(false));
2704// let retries_failed_clone = retries_failed.clone();
2705
2706// let _subscription = thread.update(cx, |_, cx| {
2707// cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
2708// if let ThreadEvent::RetriesFailed { .. } = event {
2709// *retries_failed_clone.lock() = true;
2710// }
2711// })
2712// });
2713
2714// // Start initial completion
2715// thread.update(cx, |thread, cx| {
2716// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2717// });
2718// cx.run_until_parked();
2719
2720// // Advance through all retries
2721// for i in 0..MAX_RETRY_ATTEMPTS {
2722// let delay = if i == 0 {
2723// BASE_RETRY_DELAY_SECS
2724// } else {
2725// BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
2726// };
2727// cx.executor().advance_clock(Duration::from_secs(delay));
2728// cx.run_until_parked();
2729// }
2730
2731// // After the 3rd retry is scheduled, we need to wait for it to execute and fail
2732// // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
2733// let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
2734// cx.executor()
2735// .advance_clock(Duration::from_secs(final_delay));
2736// cx.run_until_parked();
2737
2738// let retry_count = thread.update(cx, |thread, _| {
2739// thread
2740// .messages
2741// .iter()
2742// .filter(|m| {
2743// m.ui_only
2744// && m.segments.iter().any(|s| {
2745// if let MessageSegment::Text(text) = s {
2746// text.contains("Retrying") && text.contains("seconds")
2747// } else {
2748// false
2749// }
2750// })
2751// })
2752// .count()
2753// });
2754
2755// // After max retries, should emit RetriesFailed event
2756// assert_eq!(
2757// retry_count, MAX_RETRY_ATTEMPTS as usize,
2758// "Should have attempted max retries"
2759// );
2760// assert!(
2761// *retries_failed.lock(),
2762// "Should emit RetriesFailed event after max retries exceeded"
2763// );
2764
2765// // Retry state should be cleared
2766// thread.read_with(cx, |thread, _| {
2767// assert!(
2768// thread.retry_state.is_none(),
2769// "Retry state should be cleared after max retries"
2770// );
2771
2772// // Verify we have the expected number of retry messages
2773// let retry_messages = thread
2774// .messages
2775// .iter()
2776// .filter(|msg| msg.ui_only && msg.role == Role::System)
2777// .count();
2778// assert_eq!(
2779// retry_messages, MAX_RETRY_ATTEMPTS as usize,
2780// "Should have one retry message per attempt"
2781// );
2782// });
2783// }
2784
2785// #[gpui::test]
2786// async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
2787// init_test_settings(cx);
2788
2789// let project = create_test_project(cx, json!({})).await;
2790// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2791
2792// // We'll use a wrapper to switch behavior after first failure
2793// struct RetryTestModel {
2794// inner: Arc<FakeLanguageModel>,
2795// failed_once: Arc<Mutex<bool>>,
2796// }
2797
2798// impl LanguageModel for RetryTestModel {
2799// fn id(&self) -> LanguageModelId {
2800// self.inner.id()
2801// }
2802
2803// fn name(&self) -> LanguageModelName {
2804// self.inner.name()
2805// }
2806
2807// fn provider_id(&self) -> LanguageModelProviderId {
2808// self.inner.provider_id()
2809// }
2810
2811// fn provider_name(&self) -> LanguageModelProviderName {
2812// self.inner.provider_name()
2813// }
2814
2815// fn supports_tools(&self) -> bool {
2816// self.inner.supports_tools()
2817// }
2818
2819// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
2820// self.inner.supports_tool_choice(choice)
2821// }
2822
2823// fn supports_images(&self) -> bool {
2824// self.inner.supports_images()
2825// }
2826
2827// fn telemetry_id(&self) -> String {
2828// self.inner.telemetry_id()
2829// }
2830
2831// fn max_token_count(&self) -> u64 {
2832// self.inner.max_token_count()
2833// }
2834
2835// fn count_tokens(
2836// &self,
2837// request: LanguageModelRequest,
2838// cx: &App,
2839// ) -> BoxFuture<'static, Result<u64>> {
2840// self.inner.count_tokens(request, cx)
2841// }
2842
2843// fn stream_completion(
2844// &self,
2845// request: LanguageModelRequest,
2846// cx: &AsyncApp,
2847// ) -> BoxFuture<
2848// 'static,
2849// Result<
2850// BoxStream<
2851// 'static,
2852// Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
2853// >,
2854// LanguageModelCompletionError,
2855// >,
2856// > {
2857// if !*self.failed_once.lock() {
2858// *self.failed_once.lock() = true;
2859// // Return error on first attempt
2860// let stream = futures::stream::once(async move {
2861// Err(LanguageModelCompletionError::Overloaded)
2862// });
2863// async move { Ok(stream.boxed()) }.boxed()
2864// } else {
2865// // Succeed on retry
2866// self.inner.stream_completion(request, cx)
2867// }
2868// }
2869
2870// fn as_fake(&self) -> &FakeLanguageModel {
2871// &self.inner
2872// }
2873// }
2874
2875// let model = Arc::new(RetryTestModel {
2876// inner: Arc::new(FakeLanguageModel::default()),
2877// failed_once: Arc::new(Mutex::new(false)),
2878// });
2879
2880// // Insert a user message
2881// thread.update(cx, |thread, cx| {
2882// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2883// });
2884
// // Track when the retry completes successfully (via StreamedCompletion)
2887// let retry_completed = Arc::new(Mutex::new(false));
2888// let retry_completed_clone = retry_completed.clone();
2889
2890// let _subscription = thread.update(cx, |_, cx| {
2891// cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
2892// if let ThreadEvent::StreamedCompletion = event {
2893// *retry_completed_clone.lock() = true;
2894// }
2895// })
2896// });
2897
2898// // Start completion
2899// thread.update(cx, |thread, cx| {
2900// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2901// });
2902// cx.run_until_parked();
2903
2904// // Get the retry message ID
2905// let retry_message_id = thread.read_with(cx, |thread, _| {
2906// thread
2907// .messages()
2908// .find(|msg| msg.role == Role::System && msg.ui_only)
2909// .map(|msg| msg.id)
2910// .expect("Should have a retry message")
2911// });
2912
2913// // Wait for retry
2914// cx.executor()
2915// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
2916// cx.run_until_parked();
2917
2918// // Stream some successful content
2919// let fake_model = model.as_fake();
2920// // After the retry, there should be a new pending completion
2921// let pending = fake_model.pending_completions();
2922// assert!(
2923// !pending.is_empty(),
2924// "Should have a pending completion after retry"
2925// );
2926// fake_model.stream_completion_response(&pending[0], "Success!");
2927// fake_model.end_completion_stream(&pending[0]);
2928// cx.run_until_parked();
2929
2930// // Check that the retry completed successfully
2931// assert!(
2932// *retry_completed.lock(),
2933// "Retry should have completed successfully"
2934// );
2935
2936// // Retry message should still exist but be marked as ui_only
2937// thread.read_with(cx, |thread, _| {
2938// let retry_msg = thread
2939// .message(retry_message_id)
2940// .expect("Retry message should still exist");
2941// assert!(retry_msg.ui_only, "Retry message should be ui_only");
2942// assert_eq!(
2943// retry_msg.role,
2944// Role::System,
2945// "Retry message should have System role"
2946// );
2947// });
2948// }
2949
2950// #[gpui::test]
2951// async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
2952// init_test_settings(cx);
2953
2954// let project = create_test_project(cx, json!({})).await;
2955// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2956
2957// // Create a model that fails once then succeeds
2958// struct FailOnceModel {
2959// inner: Arc<FakeLanguageModel>,
2960// failed_once: Arc<Mutex<bool>>,
2961// }
2962
2963// impl LanguageModel for FailOnceModel {
2964// fn id(&self) -> LanguageModelId {
2965// self.inner.id()
2966// }
2967
2968// fn name(&self) -> LanguageModelName {
2969// self.inner.name()
2970// }
2971
2972// fn provider_id(&self) -> LanguageModelProviderId {
2973// self.inner.provider_id()
2974// }
2975
2976// fn provider_name(&self) -> LanguageModelProviderName {
2977// self.inner.provider_name()
2978// }
2979
2980// fn supports_tools(&self) -> bool {
2981// self.inner.supports_tools()
2982// }
2983
2984// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
2985// self.inner.supports_tool_choice(choice)
2986// }
2987
2988// fn supports_images(&self) -> bool {
2989// self.inner.supports_images()
2990// }
2991
2992// fn telemetry_id(&self) -> String {
2993// self.inner.telemetry_id()
2994// }
2995
2996// fn max_token_count(&self) -> u64 {
2997// self.inner.max_token_count()
2998// }
2999
3000// fn count_tokens(
3001// &self,
3002// request: LanguageModelRequest,
3003// cx: &App,
3004// ) -> BoxFuture<'static, Result<u64>> {
3005// self.inner.count_tokens(request, cx)
3006// }
3007
3008// fn stream_completion(
3009// &self,
3010// request: LanguageModelRequest,
3011// cx: &AsyncApp,
3012// ) -> BoxFuture<
3013// 'static,
3014// Result<
3015// BoxStream<
3016// 'static,
3017// Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
3018// >,
3019// LanguageModelCompletionError,
3020// >,
3021// > {
3022// if !*self.failed_once.lock() {
3023// *self.failed_once.lock() = true;
3024// // Return error on first attempt
3025// let stream = futures::stream::once(async move {
3026// Err(LanguageModelCompletionError::Overloaded)
3027// });
3028// async move { Ok(stream.boxed()) }.boxed()
3029// } else {
3030// // Succeed on retry
3031// self.inner.stream_completion(request, cx)
3032// }
3033// }
3034// }
3035
3036// let fail_once_model = Arc::new(FailOnceModel {
3037// inner: Arc::new(FakeLanguageModel::default()),
3038// failed_once: Arc::new(Mutex::new(false)),
3039// });
3040
3041// // Insert a user message
3042// thread.update(cx, |thread, cx| {
3043// thread.insert_user_message(
3044// "Test message",
3045// ContextLoadResult::default(),
3046// None,
3047// vec![],
3048// cx,
3049// );
3050// });
3051
3052// // Start completion with fail-once model
3053// thread.update(cx, |thread, cx| {
3054// thread.send_to_model(
3055// fail_once_model.clone(),
3056// CompletionIntent::UserPrompt,
3057// None,
3058// cx,
3059// );
3060// });
3061
3062// cx.run_until_parked();
3063
3064// // Verify retry state exists after first failure
3065// thread.read_with(cx, |thread, _| {
3066// assert!(
3067// thread.retry_state.is_some(),
3068// "Should have retry state after failure"
3069// );
3070// });
3071
3072// // Wait for retry delay
3073// cx.executor()
3074// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
3075// cx.run_until_parked();
3076
3077// // The retry should now use our FailOnceModel which should succeed
3078// // We need to help the FakeLanguageModel complete the stream
3079// let inner_fake = fail_once_model.inner.clone();
3080
3081// // Wait a bit for the retry to start
3082// cx.run_until_parked();
3083
3084// // Check for pending completions and complete them
3085// if let Some(pending) = inner_fake.pending_completions().first() {
3086// inner_fake.stream_completion_response(pending, "Success!");
3087// inner_fake.end_completion_stream(pending);
3088// }
3089// cx.run_until_parked();
3090
3091// thread.read_with(cx, |thread, _| {
3092// assert!(
3093// thread.retry_state.is_none(),
3094// "Retry state should be cleared after successful completion"
3095// );
3096
3097// let has_assistant_message = thread
3098// .messages
3099// .iter()
3100// .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
3101// assert!(
3102// has_assistant_message,
3103// "Should have an assistant message after successful retry"
3104// );
3105// });
3106// }
3107
3108// #[gpui::test]
3109// async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
3110// init_test_settings(cx);
3111
3112// let project = create_test_project(cx, json!({})).await;
3113// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
3114
3115// // Create a model that returns rate limit error with retry_after
3116// struct RateLimitModel {
3117// inner: Arc<FakeLanguageModel>,
3118// }
3119
3120// impl LanguageModel for RateLimitModel {
3121// fn id(&self) -> LanguageModelId {
3122// self.inner.id()
3123// }
3124
3125// fn name(&self) -> LanguageModelName {
3126// self.inner.name()
3127// }
3128
3129// fn provider_id(&self) -> LanguageModelProviderId {
3130// self.inner.provider_id()
3131// }
3132
3133// fn provider_name(&self) -> LanguageModelProviderName {
3134// self.inner.provider_name()
3135// }
3136
3137// fn supports_tools(&self) -> bool {
3138// self.inner.supports_tools()
3139// }
3140
3141// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
3142// self.inner.supports_tool_choice(choice)
3143// }
3144
3145// fn supports_images(&self) -> bool {
3146// self.inner.supports_images()
3147// }
3148
3149// fn telemetry_id(&self) -> String {
3150// self.inner.telemetry_id()
3151// }
3152
3153// fn max_token_count(&self) -> u64 {
3154// self.inner.max_token_count()
3155// }
3156
3157// fn count_tokens(
3158// &self,
3159// request: LanguageModelRequest,
3160// cx: &App,
3161// ) -> BoxFuture<'static, Result<u64>> {
3162// self.inner.count_tokens(request, cx)
3163// }
3164
3165// fn stream_completion(
3166// &self,
3167// _request: LanguageModelRequest,
3168// _cx: &AsyncApp,
3169// ) -> BoxFuture<
3170// 'static,
3171// Result<
3172// BoxStream<
3173// 'static,
3174// Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
3175// >,
3176// LanguageModelCompletionError,
3177// >,
3178// > {
3179// async move {
3180// let stream = futures::stream::once(async move {
3181// Err(LanguageModelCompletionError::RateLimitExceeded {
3182// retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
3183// })
3184// });
3185// Ok(stream.boxed())
3186// }
3187// .boxed()
3188// }
3189
3190// fn as_fake(&self) -> &FakeLanguageModel {
3191// &self.inner
3192// }
3193// }
3194
3195// let model = Arc::new(RateLimitModel {
3196// inner: Arc::new(FakeLanguageModel::default()),
3197// });
3198
3199// // Insert a user message
3200// thread.update(cx, |thread, cx| {
3201// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
3202// });
3203
3204// // Start completion
3205// thread.update(cx, |thread, cx| {
3206// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3207// });
3208
3209// cx.run_until_parked();
3210
3211// let retry_count = thread.update(cx, |thread, _| {
3212// thread
3213// .messages
3214// .iter()
3215// .filter(|m| {
3216// m.ui_only
3217// && m.segments.iter().any(|s| {
3218// if let MessageSegment::Text(text) = s {
3219// text.contains("rate limit exceeded")
3220// } else {
3221// false
3222// }
3223// })
3224// })
3225// .count()
3226// });
3227// assert_eq!(retry_count, 1, "Should have scheduled one retry");
3228
3229// thread.read_with(cx, |thread, _| {
3230// assert!(
3231// thread.retry_state.is_none(),
3232// "Rate limit errors should not set retry_state"
3233// );
3234// });
3235
3236// // Verify we have one retry message
3237// thread.read_with(cx, |thread, _| {
3238// let retry_messages = thread
3239// .messages
3240// .iter()
3241// .filter(|msg| {
3242// msg.ui_only
3243// && msg.segments.iter().any(|seg| {
3244// if let MessageSegment::Text(text) = seg {
3245// text.contains("rate limit exceeded")
3246// } else {
3247// false
3248// }
3249// })
3250// })
3251// .count();
3252// assert_eq!(
3253// retry_messages, 1,
3254// "Should have one rate limit retry message"
3255// );
3256// });
3257
3258// // Check that retry message doesn't include attempt count
3259// thread.read_with(cx, |thread, _| {
3260// let retry_message = thread
3261// .messages
3262// .iter()
3263// .find(|msg| msg.role == Role::System && msg.ui_only)
3264// .expect("Should have a retry message");
3265
3266// // Check that the message doesn't contain attempt count
3267// if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
3268// assert!(
3269// !text.contains("attempt"),
3270// "Rate limit retry message should not contain attempt count"
3271// );
3272// assert!(
3273// text.contains(&format!(
3274// "Retrying in {} seconds",
3275// TEST_RATE_LIMIT_RETRY_SECS
3276// )),
3277// "Rate limit retry message should contain retry delay"
3278// );
3279// }
3280// });
3281// }
3282
3283// #[gpui::test]
3284// async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
3285// init_test_settings(cx);
3286
3287// let project = create_test_project(cx, json!({})).await;
3288// let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
3289
3290// // Insert a regular user message
3291// thread.update(cx, |thread, cx| {
3292// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
3293// });
3294
3295// // Insert a UI-only message (like our retry notifications)
3296// thread.update(cx, |thread, cx| {
3297// let id = thread.next_message_id.post_inc();
3298// thread.messages.push(Message {
3299// id,
3300// role: Role::System,
3301// segments: vec![MessageSegment::Text(
3302// "This is a UI-only message that should not be sent to the model".to_string(),
3303// )],
3304// loaded_context: LoadedContext::default(),
3305// creases: Vec::new(),
3306// is_hidden: true,
3307// ui_only: true,
3308// });
3309// cx.emit(ThreadEvent::MessageAdded(id));
3310// });
3311
3312// // Insert another regular message
3313// thread.update(cx, |thread, cx| {
3314// thread.insert_user_message(
3315// "How are you?",
3316// ContextLoadResult::default(),
3317// None,
3318// vec![],
3319// cx,
3320// );
3321// });
3322
3323// // Generate the completion request
3324// let request = thread.update(cx, |thread, cx| {
3325// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3326// });
3327
3328// // Verify that the request only contains non-UI-only messages
3329// // Should have system prompt + 2 user messages, but not the UI-only message
3330// let user_messages: Vec<_> = request
3331// .messages
3332// .iter()
3333// .filter(|msg| msg.role == Role::User)
3334// .collect();
3335// assert_eq!(
3336// user_messages.len(),
3337// 2,
3338// "Should have exactly 2 user messages"
3339// );
3340
3341// // Verify the UI-only content is not present anywhere in the request
3342// let request_text = request
3343// .messages
3344// .iter()
3345// .flat_map(|msg| &msg.content)
3346// .filter_map(|content| match content {
3347// MessageContent::Text(text) => Some(text.as_str()),
3348// _ => None,
3349// })
3350// .collect::<String>();
3351
3352// assert!(
3353// !request_text.contains("UI-only message"),
3354// "UI-only message content should not be in the request"
3355// );
3356
3357// // Verify the thread still has all 3 messages (including UI-only)
3358// thread.read_with(cx, |thread, _| {
3359// assert_eq!(
3360// thread.messages().count(),
3361// 3,
3362// "Thread should have 3 messages"
3363// );
3364// assert_eq!(
3365// thread.messages().filter(|m| m.ui_only).count(),
3366// 1,
3367// "Thread should have 1 UI-only message"
3368// );
3369// });
3370
3371// // Verify that UI-only messages are not serialized
3372// let serialized = thread
3373// .update(cx, |thread, cx| thread.serialize(cx))
3374// .await
3375// .unwrap();
3376// assert_eq!(
3377// serialized.messages.len(),
3378// 2,
3379// "Serialized thread should only have 2 messages (no UI-only)"
3380// );
3381// }
3382
3383// #[gpui::test]
3384// async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
3385// init_test_settings(cx);
3386
3387// let project = create_test_project(cx, json!({})).await;
3388// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
3389
3390// // Create model that returns overloaded error
3391// let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
3392
3393// // Insert a user message
3394// thread.update(cx, |thread, cx| {
3395// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
3396// });
3397
3398// // Start completion
3399// thread.update(cx, |thread, cx| {
3400// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3401// });
3402
3403// cx.run_until_parked();
3404
3405// // Verify retry was scheduled by checking for retry message
3406// let has_retry_message = thread.read_with(cx, |thread, _| {
3407// thread.messages.iter().any(|m| {
3408// m.ui_only
3409// && m.segments.iter().any(|s| {
3410// if let MessageSegment::Text(text) = s {
3411// text.contains("Retrying") && text.contains("seconds")
3412// } else {
3413// false
3414// }
3415// })
3416// })
3417// });
3418// assert!(has_retry_message, "Should have scheduled a retry");
3419
3420// // Cancel the completion before the retry happens
3421// thread.update(cx, |thread, cx| {
3422// thread.cancel_last_completion(None, cx);
3423// });
3424
3425// cx.run_until_parked();
3426
3427// // The retry should not have happened - no pending completions
3428// let fake_model = model.as_fake();
3429// assert_eq!(
3430// fake_model.pending_completions().len(),
3431// 0,
3432// "Should have no pending completions after cancellation"
3433// );
3434
3435// // Verify the retry was cancelled by checking retry state
3436// thread.read_with(cx, |thread, _| {
3437// if let Some(retry_state) = &thread.retry_state {
3438// panic!(
3439// "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
3440// retry_state.attempt, retry_state.max_attempts, retry_state.intent
3441// );
3442// }
3443// });
3444// }
3445
3446// fn test_summarize_error(
3447// model: &Arc<dyn LanguageModel>,
3448// thread: &Entity<Thread>,
3449// cx: &mut TestAppContext,
3450// ) {
3451// thread.update(cx, |thread, cx| {
3452// thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3453// thread.send_to_model(
3454// model.clone(),
3455// CompletionIntent::ThreadSummarization,
3456// None,
3457// cx,
3458// );
3459// });
3460
3461// let fake_model = model.as_fake();
3462// simulate_successful_response(&fake_model, cx);
3463
3464// thread.read_with(cx, |thread, _| {
3465// assert!(matches!(thread.summary(), ThreadSummary::Generating));
3466// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3467// });
3468
3469// // Simulate summary request ending
3470// cx.run_until_parked();
3471// fake_model.end_last_completion_stream();
3472// cx.run_until_parked();
3473
3474// // State is set to Error and default message
3475// thread.read_with(cx, |thread, _| {
3476// assert!(matches!(thread.summary(), ThreadSummary::Error));
3477// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3478// });
3479// }
3480
3481// fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
3482// cx.run_until_parked();
3483// fake_model.stream_last_completion_response("Assistant response");
3484// fake_model.end_last_completion_stream();
3485// cx.run_until_parked();
3486// }
3487
3488// fn init_test_settings(cx: &mut TestAppContext) {
3489// cx.update(|cx| {
3490// let settings_store = SettingsStore::test(cx);
3491// cx.set_global(settings_store);
3492// language::init(cx);
3493// Project::init_settings(cx);
3494// AgentSettings::register(cx);
3495// prompt_store::init(cx);
3496// thread_store::init(cx);
3497// workspace::init_settings(cx);
3498// language_model::init_settings(cx);
3499// ThemeSettings::register(cx);
3500// ToolRegistry::default_global(cx);
3501// });
3502// }
3503
3504// // Helper to create a test project with test files
3505// async fn create_test_project(
3506// cx: &mut TestAppContext,
3507// files: serde_json::Value,
3508// ) -> Entity<Project> {
3509// let fs = FakeFs::new(cx.executor());
3510// fs.insert_tree(path!("/test"), files).await;
3511// Project::test(fs, [path!("/test").as_ref()], cx).await
3512// }
3513
3514// async fn setup_test_environment(
3515// cx: &mut TestAppContext,
3516// project: Entity<Project>,
3517// ) -> (
3518// Entity<Workspace>,
3519// Entity<ThreadStore>,
3520// Entity<Thread>,
3521// Entity<ContextStore>,
3522// Arc<dyn LanguageModel>,
3523// ) {
3524// let (workspace, cx) =
3525// cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3526
3527// let thread_store = cx
3528// .update(|_, cx| {
3529// ThreadStore::load(
3530// project.clone(),
3531// cx.new(|_| ToolWorkingSet::default()),
3532// None,
3533// Arc::new(PromptBuilder::new(None).unwrap()),
3534// cx,
3535// )
3536// })
3537// .await
3538// .unwrap();
3539
3540// let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3541// let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3542
3543// let provider = Arc::new(FakeLanguageModelProvider);
3544// let model = provider.test_model();
3545// let model: Arc<dyn LanguageModel> = Arc::new(model);
3546
3547// cx.update(|_, cx| {
3548// LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3549// registry.set_default_model(
3550// Some(ConfiguredModel {
3551// provider: provider.clone(),
3552// model: model.clone(),
3553// }),
3554// cx,
3555// );
3556// registry.set_thread_summary_model(
3557// Some(ConfiguredModel {
3558// provider,
3559// model: model.clone(),
3560// }),
3561// cx,
3562// );
3563// })
3564// });
3565
3566// (workspace, thread_store, thread, context_store, model)
3567// }
3568
3569// async fn add_file_to_context(
3570// project: &Entity<Project>,
3571// context_store: &Entity<ContextStore>,
3572// path: &str,
3573// cx: &mut TestAppContext,
3574// ) -> Result<Entity<language::Buffer>> {
3575// let buffer_path = project
3576// .read_with(cx, |project, cx| project.find_project_path(path, cx))
3577// .unwrap();
3578
3579// let buffer = project
3580// .update(cx, |project, cx| {
3581// project.open_buffer(buffer_path.clone(), cx)
3582// })
3583// .await
3584// .unwrap();
3585
3586// context_store.update(cx, |context_store, cx| {
3587// context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3588// });
3589
3590// Ok(buffer)
3591// }
3592// }