1use crate::{
2 AgentThread, AgentThreadUserMessageChunk, MessageId, ThreadId,
3 agent_profile::AgentProfile,
4 context::{AgentContextHandle, LoadedContext},
5 thread_store::{SharedProjectContext, ThreadStore},
6};
7use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
8use anyhow::Result;
9use assistant_tool::{ActionLog, AnyToolCard, Tool};
10use chrono::{DateTime, Utc};
11use client::{ModelRequestUsage, RequestUsage};
12use collections::{HashMap, HashSet};
13use futures::{FutureExt, channel::oneshot, future::Shared};
14use git::repository::DiffType;
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
17 Window,
18};
19use language_model::{
20 ConfiguredModel, LanguageModelId, LanguageModelToolUseId, Role, StopReason, TokenUsage,
21};
22use markdown::Markdown;
23use postage::stream::Stream as _;
24use project::{
25 Project,
26 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
27};
28use proto::Plan;
29use serde::{Deserialize, Serialize};
30use settings::Settings;
31use std::{ops::Range, sync::Arc, time::Instant};
32use thiserror::Error;
33use util::ResultExt as _;
34use zed_llm_client::UsageLimit;
35
36/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text that the crease covers
    /// (presumably byte offsets — TODO confirm with creator).
    pub range: Range<usize>,
    /// Icon shown on the collapsed crease.
    pub icon_path: SharedString,
    /// Label shown on the collapsed crease.
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
45
/// Lifecycle of a tool invocation shown within a message.
pub enum MessageToolCall {
    /// The tool call has been requested but not yet confirmed or run.
    Pending {
        tool: Arc<dyn Tool>,
        input: serde_json::Value,
    },
    /// Waiting for the user's decision, which is delivered via `confirm_tx`.
    NeedsConfirmation {
        tool: Arc<dyn Tool>,
        input_json: serde_json::Value,
        confirm_tx: oneshot::Sender<bool>,
    },
    /// The call was approved; `card` renders it.
    Confirmed {
        card: AnyToolCard,
    },
    /// The user declined the call.
    Declined {
        tool: Arc<dyn Tool>,
        input_json: serde_json::Value,
    },
}
64
/// A message in a [`Thread`].
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered body of the message: text, thinking, and tool-call segments.
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded and attached to this message.
    pub loaded_context: LoadedContext,
    /// Creases used to rebuild context pills when re-editing this message.
    pub creases: Vec<MessageCrease>,
    pub is_hidden: bool, // todo!("do we need this?")
    pub ui_only: bool, // todo!("do we need this?")
}
75
76pub enum Message {
77 User {
78 text: String,
79 creases: Vec<MessageCrease>,
80 },
81 Assistant {
82 segments: Vec<MessageSegment>,
83 },
84}
85
/// One renderable chunk of a message body.
pub enum MessageSegment {
    /// Plain message text, rendered as markdown.
    Text(Entity<Markdown>),
    /// Model "thinking"/reasoning content, rendered as markdown.
    Thinking(Entity<Markdown>),
    /// An embedded tool invocation (see [`MessageToolCall`]).
    ToolCall(MessageToolCall),
}
91
/// Snapshot of the project's state, captured when a thread starts and when
/// feedback is reported.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken (UTC).
    pub timestamp: DateTime<Utc>,
}
98
/// Per-worktree portion of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    /// Absolute path of the worktree root (empty string if it could not be read).
    pub worktree_path: String,
    /// Git metadata, when the worktree belongs to a repository.
    pub git_state: Option<GitState>,
}

/// Git repository metadata captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    /// URL of the "origin" remote, if configured.
    pub remote_url: Option<String>,
    /// SHA of the current HEAD commit, if available.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    /// Diff from HEAD to the working tree, if it could be computed.
    pub diff: Option<String>,
}
112
/// Associates a git checkpoint with the message that created it, so the
/// project can be rolled back to the state before that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User rating of a thread or of an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
124
/// State of the most recent checkpoint restoration attempt.
pub enum LastRestoreCheckpoint {
    /// Restoration is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// Restoration failed with the given error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
134
135impl LastRestoreCheckpoint {
136 pub fn message_id(&self) -> MessageId {
137 match self {
138 LastRestoreCheckpoint::Pending { message_id } => *message_id,
139 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
140 }
141 }
142}
143
/// State machine for the lazily generated detailed thread summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    /// No summary has been generated yet.
    #[default]
    NotGenerated,
    /// A summary covering messages up to `message_id` is being generated.
    Generating {
        message_id: MessageId,
    },
    /// A summary covering messages up to `message_id` is available.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
156
157impl DetailedSummaryState {
158 fn text(&self) -> Option<SharedString> {
159 if let Self::Generated { text, .. } = self {
160 Some(text.clone())
161 } else {
162 None
163 }
164 }
165}
166
/// Running token count for a thread together with the model's context limit.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Context-window size; 0 when no model is selected.
    pub max: u64,
}
172
173impl TotalTokenUsage {
174 pub fn ratio(&self) -> TokenUsageRatio {
175 #[cfg(debug_assertions)]
176 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
177 .unwrap_or("0.8".to_string())
178 .parse()
179 .unwrap();
180 #[cfg(not(debug_assertions))]
181 let warning_threshold: f32 = 0.8;
182
183 // When the maximum is unknown because there is no selected model,
184 // avoid showing the token limit warning.
185 if self.max == 0 {
186 TokenUsageRatio::Normal
187 } else if self.total >= self.max {
188 TokenUsageRatio::Exceeded
189 } else if self.total as f32 / self.max as f32 >= warning_threshold {
190 TokenUsageRatio::Warning
191 } else {
192 TokenUsageRatio::Normal
193 }
194 }
195
196 pub fn add(&self, tokens: u64) -> TotalTokenUsage {
197 TotalTokenUsage {
198 total: self.total + tokens,
199 max: self.max,
200 }
201 }
202}
203
/// How close the thread's token usage is to the context window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    /// Usage has crossed the warning threshold (see `TotalTokenUsage::ratio`).
    Warning,
    /// Usage has reached or exceeded the context window.
    Exceeded,
}
211
/// Progress of a completion request through the send queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is waiting at `position` in the queue.
    Queued { position: usize },
    /// The request has started executing.
    Started,
}
218
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Handle to the underlying agent-side thread.
    agent_thread: Arc<dyn AgentThread>,
    title: ThreadTitle,
    /// In-flight send; `Some` while the model is generating.
    pending_send: Option<Task<Result<()>>>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    // Watch channel broadcasting the detailed-summary state to observers.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    messages: Vec<Message>,
    /// Git checkpoints recorded per message, used for rollback.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured for the current turn, finalized once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot taken when the thread was created; shared for reuse.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    // todo!(keep track of retries from the underlying agent)
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// When the last streamed chunk arrived; used for staleness detection.
    last_received_chunk_at: Option<Instant>,
}
246
/// Generation state of a thread's title.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadTitle {
    /// Title generation has not started.
    Pending,
    /// A title is currently being generated.
    Generating,
    /// A title was produced successfully.
    Ready(SharedString),
    /// Title generation failed.
    Error,
}
254
255impl ThreadTitle {
256 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
257
258 pub fn or_default(&self) -> SharedString {
259 self.unwrap_or(Self::DEFAULT)
260 }
261
262 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
263 self.ready().unwrap_or_else(|| message.into())
264 }
265
266 pub fn ready(&self) -> Option<SharedString> {
267 match self {
268 ThreadTitle::Ready(summary) => Some(summary.clone()),
269 ThreadTitle::Pending | ThreadTitle::Generating | ThreadTitle::Error => None,
270 }
271 }
272}
273
/// Recorded when a request fails because it exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
281
282impl Thread {
    /// Creates a `Thread` wrapper around an existing agent-side thread.
    ///
    /// NOTE(review): still WIP — `messages` is `todo!`, so calling this
    /// currently panics.
    pub fn load(
        agent_thread: Arc<dyn AgentThread>,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        Self {
            agent_thread,
            title: ThreadTitle::Pending,
            pending_send: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: todo!("read from agent"),
            checkpoints_by_message: HashMap::default(),
            project: project.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project state as of thread creation; `Shared` lets
            // multiple consumers await the same snapshot.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
        }
    }
320
    /// The id of the underlying agent thread.
    pub fn id(&self) -> ThreadId {
        self.agent_thread.id()
    }

    /// The agent profile in use (unimplemented).
    pub fn profile(&self) -> &AgentProfile {
        todo!()
    }

    /// Switches to the profile identified by `id` (unimplemented).
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        todo!()
        // if &id != self.profile.id() {
        //     self.profile = AgentProfile::new(id, self.tools.clone());
        //     cx.emit(ThreadEvent::ProfileChanged);
        // }
    }

    /// Whether the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// The shared project context (unimplemented).
    pub fn project_context(&self) -> SharedProjectContext {
        todo!()
        // self.project_context.clone()
    }

    /// The thread's title state.
    pub fn title(&self) -> &ThreadTitle {
        &self.title
    }
349
    /// Replaces the thread's title (unimplemented).
    ///
    /// NOTE(review): the commented reference code below still uses the old
    /// `summary`/`ThreadSummary` naming while the parameter is `new_title` —
    /// reconcile when implementing.
    pub fn set_title(&mut self, new_title: impl Into<SharedString>, cx: &mut Context<Self>) {
        todo!()
        // let current_summary = match &self.summary {
        //     ThreadSummary::Pending | ThreadSummary::Generating => return,
        //     ThreadSummary::Ready(summary) => summary,
        //     ThreadSummary::Error => &ThreadSummary::DEFAULT,
        // };

        // let mut new_summary = new_summary.into();

        // if new_summary.is_empty() {
        //     new_summary = ThreadSummary::DEFAULT;
        // }

        // if current_summary != &new_summary {
        //     self.summary = ThreadSummary::Ready(new_summary);
        //     cx.emit(ThreadEvent::SummaryChanged);
        // }
    }

    /// Forces regeneration of the thread summary (unimplemented).
    pub fn regenerate_summary(&self, cx: &mut Context<Self>) {
        todo!()
        // self.summarize(cx);
    }
374
    /// The current completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// All messages in the thread, in order.
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }

    /// Whether a send is currently in flight.
    pub fn is_generating(&self) -> bool {
        self.pending_send.is_some()
    }
390
391 /// Indicates whether streaming of language model events is stale.
392 /// When `is_generating()` is false, this method returns `None`.
393 pub fn is_generation_stale(&self) -> Option<bool> {
394 const STALE_THRESHOLD: u128 = 250;
395
396 self.last_received_chunk_at
397 .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
398 }
399
400 fn received_chunk(&mut self) {
401 self.last_received_chunk_at = Some(Instant::now());
402 }
403
    /// The checkpoint recorded for message `id`, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Surface a pending state immediately so the UI can reflect the
        // in-flight restore.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failure around so the UI can report it.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
443
444 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
445 let pending_checkpoint = if self.is_generating() {
446 return;
447 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
448 checkpoint
449 } else {
450 return;
451 };
452
453 self.finalize_checkpoint(pending_checkpoint, cx);
454 }
455
    /// Compares `pending_checkpoint` against the repository's current state
    /// and records it only when something actually changed during the turn
    /// (or when the comparison itself could not be made).
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Identical checkpoints would be noise; only keep real changes.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint — err on the side of keeping
            // the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
497
    /// Outcome of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes `message_id` and every later message from the thread.
    ///
    /// NOTE(review): the leading `todo!` makes the code below unreachable
    /// until truncation is delegated to the agent.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        todo!("call truncate on the agent");
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        // Drop the checkpoints belonging to the removed messages.
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
516
    /// Whether the message at index `ix` ends a turn (unimplemented).
    pub fn is_turn_end(&self, ix: usize) -> bool {
        todo!()
        // if self.messages.is_empty() {
        //     return false;
        // }

        // if !self.is_generating() && ix == self.messages.len() - 1 {
        //     return true;
        // }

        // let Some(message) = self.messages.get(ix) else {
        //     return false;
        // };

        // if message.role != Role::Assistant {
        //     return false;
        // }

        // self.messages
        //     .get(ix + 1)
        //     .and_then(|message| {
        //         self.message(message.id)
        //             .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
        //     })
        //     .unwrap_or(false)
    }

    /// Whether generation stopped because the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        todo!()
    }
552
553 // pub fn insert_user_message(
554 // &mut self,
555 // text: impl Into<String>,
556 // loaded_context: ContextLoadResult,
557 // git_checkpoint: Option<GitStoreCheckpoint>,
558 // creases: Vec<MessageCrease>,
559 // cx: &mut Context<Self>,
560 // ) -> AgentThreadMessageId {
561 // todo!("move this logic into send")
562 // if !loaded_context.referenced_buffers.is_empty() {
563 // self.action_log.update(cx, |log, cx| {
564 // for buffer in loaded_context.referenced_buffers {
565 // log.buffer_read(buffer, cx);
566 // }
567 // });
568 // }
569
570 // let message_id = self.insert_message(
571 // Role::User,
572 // vec![MessageSegment::Text(text.into())],
573 // loaded_context.loaded_context,
574 // creases,
575 // false,
576 // cx,
577 // );
578
579 // if let Some(git_checkpoint) = git_checkpoint {
580 // self.pending_checkpoint = Some(ThreadCheckpoint {
581 // message_id,
582 // git_checkpoint,
583 // });
584 // }
585
586 // self.auto_capture_telemetry(cx);
587
588 // message_id
589 // }
590
    /// Sets the model used for completions (unimplemented).
    pub fn set_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        todo!()
    }

    /// The currently configured model (unimplemented).
    pub fn model(&self) -> Option<ConfiguredModel> {
        todo!()
    }

    /// Sends a new user message to the agent (unimplemented).
    pub fn send(
        &mut self,
        message: Vec<AgentThreadUserMessageChunk>,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        todo!()
    }

    /// Resumes a stopped generation (unimplemented).
    pub fn resume(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        todo!()
    }

    /// Replaces the content of an existing message (unimplemented).
    pub fn edit(
        &mut self,
        message_id: MessageId,
        message: Vec<AgentThreadUserMessageChunk>,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        todo!()
    }

    /// Cancels in-flight generation (unimplemented); the `bool` presumably
    /// reports whether anything was cancelled — TODO confirm when implementing.
    pub fn cancel(&mut self, window: &mut Window, cx: &mut Context<Self>) -> bool {
        todo!()
    }
625
626 // pub fn insert_invisible_continue_message(
627 // &mut self,
628 // cx: &mut Context<Self>,
629 // ) -> AgentThreadMessageId {
630 // let id = self.insert_message(
631 // Role::User,
632 // vec![MessageSegment::Text("Continue where you left off".into())],
633 // LoadedContext::default(),
634 // vec![],
635 // true,
636 // cx,
637 // );
638 // self.pending_checkpoint = None;
639
640 // id
641 // }
642
643 // pub fn insert_assistant_message(
644 // &mut self,
645 // segments: Vec<MessageSegment>,
646 // cx: &mut Context<Self>,
647 // ) -> AgentThreadMessageId {
648 // self.insert_message(
649 // Role::Assistant,
650 // segments,
651 // LoadedContext::default(),
652 // Vec::new(),
653 // false,
654 // cx,
655 // )
656 // }
657
658 // pub fn insert_message(
659 // &mut self,
660 // role: Role,
661 // segments: Vec<MessageSegment>,
662 // loaded_context: LoadedContext,
663 // creases: Vec<MessageCrease>,
664 // is_hidden: bool,
665 // cx: &mut Context<Self>,
666 // ) -> AgentThreadMessageId {
667 // let id = self.next_message_id.post_inc();
668 // self.messages.push(Message {
669 // id,
670 // role,
671 // segments,
672 // loaded_context,
673 // creases,
674 // is_hidden,
675 // ui_only: false,
676 // });
677 // self.touch_updated_at();
678 // cx.emit(ThreadEvent::MessageAdded(id));
679 // id
680 // }
681
682 // pub fn edit_message(
683 // &mut self,
684 // id: AgentThreadMessageId,
685 // new_role: Role,
686 // new_segments: Vec<MessageSegment>,
687 // creases: Vec<MessageCrease>,
688 // loaded_context: Option<LoadedContext>,
689 // checkpoint: Option<GitStoreCheckpoint>,
690 // cx: &mut Context<Self>,
691 // ) -> bool {
692 // let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
693 // return false;
694 // };
695 // message.role = new_role;
696 // message.segments = new_segments;
697 // message.creases = creases;
698 // if let Some(context) = loaded_context {
699 // message.loaded_context = context;
700 // }
701 // if let Some(git_checkpoint) = checkpoint {
702 // self.checkpoints_by_message.insert(
703 // id,
704 // ThreadCheckpoint {
705 // message_id: id,
706 // git_checkpoint,
707 // },
708 // );
709 // }
710 // self.touch_updated_at();
711 // cx.emit(ThreadEvent::MessageEdited(id));
712 // true
713 // }
714
715 /// Returns the representation of this [`Thread`] in a textual form.
716 ///
717 /// This is the representation we use when attaching a thread as context to another thread.
718 pub fn text(&self, cx: &App) -> String {
719 let mut text = String::new();
720
721 for message in &self.messages {
722 text.push_str(match message.role {
723 language_model::Role::User => "User:",
724 language_model::Role::Assistant => "Agent:",
725 language_model::Role::System => "System:",
726 });
727 text.push('\n');
728
729 text.push_str("<think>");
730 text.push_str(message.thinking.read(cx).source());
731 text.push_str("</think>");
732 text.push_str(message.text.read(cx).source());
733
734 // todo!('what about tools?');
735
736 text.push('\n');
737 }
738
739 text
740 }
741
    /// Whether the assistant has used tools since the last user message
    /// (unimplemented).
    pub fn used_tools_since_last_user_message(&self) -> bool {
        todo!()
        // for message in self.messages.iter().rev() {
        //     if self.tool_use.message_has_tool_results(message.id) {
        //         return true;
        //     } else if message.role == Role::User {
        //         return false;
        //     }
        // }

        // false
    }
754
    /// Kicks off generation of a detailed summary unless one covering the
    /// latest message already exists or is being generated.
    ///
    /// NOTE(review): `thread_store` is currently unused — confirm whether it
    /// is still needed.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let summary = self.agent_thread.summary();

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let Some(summary) = summary.await.log_err() else {
                // Generation failed; reset so a later call can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            Some(())
        });
    }
808
    /// Waits until detailed-summary generation settles and returns the
    /// summary, falling back to the plain-text rendering of the thread when
    /// no summary ends up being generated.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating — wait for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, cx| this.text(cx).into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
826
827 pub fn latest_detailed_summary_or_text(&self, cx: &App) -> SharedString {
828 self.detailed_summary_rx
829 .borrow()
830 .text()
831 .unwrap_or_else(|| self.text(cx).into())
832 }
833
    /// Whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }

    /// The thread-level feedback, if the user rated the thread.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }

    /// The feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
848
    /// Records user feedback for a single message and reports it via
    /// telemetry (unimplemented).
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        todo!()
        // if self.message_feedback.get(&message_id) == Some(&feedback) {
        //     return Task::ready(Ok(()));
        // }

        // let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        // let serialized_thread = self.serialize(cx);
        // let thread_id = self.id().clone();
        // let client = self.project.read(cx).client();

        // let enabled_tool_names: Vec<String> = self
        //     .profile
        //     .enabled_tools(cx)
        //     .iter()
        //     .map(|tool| tool.name())
        //     .collect();

        // self.message_feedback.insert(message_id, feedback);

        // cx.notify();

        // let message_content = self
        //     .message(message_id)
        //     .map(|msg| msg.to_string())
        //     .unwrap_or_default();

        // cx.background_spawn(async move {
        //     let final_project_snapshot = final_project_snapshot.await;
        //     let serialized_thread = serialized_thread.await?;
        //     let thread_data =
        //         serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

        //     let rating = match feedback {
        //         ThreadFeedback::Positive => "positive",
        //         ThreadFeedback::Negative => "negative",
        //     };
        //     telemetry::event!(
        //         "Assistant Thread Rated",
        //         rating,
        //         thread_id,
        //         enabled_tool_names,
        //         message_id = message_id,
        //         message_content,
        //         thread_data,
        //         final_project_snapshot
        //     );
        //     client.telemetry().flush_events().await;

        //     Ok(())
        // })
    }
906
    /// Records thread-level feedback and reports it via telemetry
    /// (unimplemented).
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        todo!()
        // let last_assistant_message_id = self
        //     .messages
        //     .iter()
        //     .rev()
        //     .find(|msg| msg.role == Role::Assistant)
        //     .map(|msg| msg.id);

        // if let Some(message_id) = last_assistant_message_id {
        //     self.report_message_feedback(message_id, feedback, cx)
        // } else {
        //     let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        //     let serialized_thread = self.serialize(cx);
        //     let thread_id = self.id().clone();
        //     let client = self.project.read(cx).client();
        //     self.feedback = Some(feedback);
        //     cx.notify();

        //     cx.background_spawn(async move {
        //         let final_project_snapshot = final_project_snapshot.await;
        //         let serialized_thread = serialized_thread.await?;
        //         let thread_data = serde_json::to_value(serialized_thread)
        //             .unwrap_or_else(|_| serde_json::Value::Null);

        //         let rating = match feedback {
        //             ThreadFeedback::Positive => "positive",
        //             ThreadFeedback::Negative => "negative",
        //         };
        //         telemetry::event!(
        //             "Assistant Thread Rated",
        //             rating,
        //             thread_id,
        //             thread_data,
        //             final_project_snapshot
        //         );
        //         client.telemetry().flush_events().await;

        //         Ok(())
        //     })
        // }
    }
953
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect the paths of all dirty (unsaved) open buffers; best
            // effort — failure to read the app state just yields an empty list.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
991
    /// Captures a [`WorktreeSnapshot`] (path plus optional git metadata) for
    /// one worktree. Every failure degrades gracefully to `None`/empty values.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository containing this worktree, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        // Remote URL, HEAD, and diff are only available for
                        // local repositories.
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the Option<Result<Task>> produced above, treating any
            // failure as "no git state".
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
1069
    /// Renders the whole thread as a markdown document (unimplemented).
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        todo!()
        // let mut markdown = Vec::new();

        // let summary = self.summary().or_default();
        // writeln!(markdown, "# {summary}\n")?;

        // for message in self.messages() {
        //     writeln!(
        //         markdown,
        //         "## {role}\n",
        //         role = match message.role {
        //             Role::User => "User",
        //             Role::Assistant => "Agent",
        //             Role::System => "System",
        //         }
        //     )?;

        //     if !message.loaded_context.text.is_empty() {
        //         writeln!(markdown, "{}", message.loaded_context.text)?;
        //     }

        //     if !message.loaded_context.images.is_empty() {
        //         writeln!(
        //             markdown,
        //             "\n{} images attached as context.\n",
        //             message.loaded_context.images.len()
        //         )?;
        //     }

        //     for segment in &message.segments {
        //         match segment {
        //             MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
        //             MessageSegment::Thinking { text, .. } => {
        //                 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
        //             }
        //             MessageSegment::RedactedThinking(_) => {}
        //         }
        //     }

        //     for tool_use in self.tool_uses_for_message(message.id, cx) {
        //         writeln!(
        //             markdown,
        //             "**Use Tool: {} ({})**",
        //             tool_use.name, tool_use.id
        //         )?;
        //         writeln!(markdown, "```json")?;
        //         writeln!(
        //             markdown,
        //             "{}",
        //             serde_json::to_string_pretty(&tool_use.input)?
        //         )?;
        //         writeln!(markdown, "```")?;
        //     }

        //     for tool_result in self.tool_results_for_message(message.id) {
        //         write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
        //         if tool_result.is_error {
        //             write!(markdown, " (Error)")?;
        //         }

        //         writeln!(markdown, "**\n")?;
        //         match &tool_result.content {
        //             LanguageModelToolResultContent::Text(text) => {
        //                 writeln!(markdown, "{text}")?;
        //             }
        //             LanguageModelToolResultContent::Image(image) => {
        //                 writeln!(markdown, "", image.source)?;
        //             }
        //         }

        //         if let Some(output) = tool_result.output.as_ref() {
        //             writeln!(
        //                 markdown,
        //                 "\n\nDebug Output:\n\n```json\n{}\n```\n",
        //                 serde_json::to_string_pretty(output)?
        //             )?;
        //         }
        //     }
        // }

        // Ok(String::from_utf8_lossy(&markdown).to_string())
    }
1153
1154 pub fn keep_edits_in_range(
1155 &mut self,
1156 buffer: Entity<language::Buffer>,
1157 buffer_range: Range<language::Anchor>,
1158 cx: &mut Context<Self>,
1159 ) {
1160 self.action_log.update(cx, |action_log, cx| {
1161 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1162 });
1163 }
1164
1165 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1166 self.action_log
1167 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1168 }
1169
1170 pub fn reject_edits_in_ranges(
1171 &mut self,
1172 buffer: Entity<language::Buffer>,
1173 buffer_ranges: Vec<Range<language::Anchor>>,
1174 cx: &mut Context<Self>,
1175 ) -> Task<Result<()>> {
1176 self.action_log.update(cx, |action_log, cx| {
1177 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1178 })
1179 }
1180
    /// Handle to the [`ActionLog`] that tracks this thread's buffer edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1184
    /// Handle to the [`Project`] this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1188
    /// Opportunistically captures a serialized snapshot of this thread for telemetry.
    ///
    /// Currently unimplemented (`todo!`). The commented-out code below is the
    /// previous implementation, kept for reference: it was gated on a feature
    /// flag, throttled via `last_auto_capture_at` to at most once per 10
    /// seconds, and serialized the thread before emitting a telemetry event
    /// and flushing in a background task.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        todo!()
        // if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
        //     return;
        // }

        // let now = Instant::now();
        // if let Some(last) = self.last_auto_capture_at {
        //     if now.duration_since(last).as_secs() < 10 {
        //         return;
        //     }
        // }

        // self.last_auto_capture_at = Some(now);

        // let thread_id = self.id().clone();
        // let github_login = self
        //     .project
        //     .read(cx)
        //     .user_store()
        //     .read(cx)
        //     .current_user()
        //     .map(|user| user.github_login.clone());
        // let client = self.project.read(cx).client();
        // let serialize_task = self.serialize(cx);

        // cx.background_executor()
        //     .spawn(async move {
        //         if let Ok(serialized_thread) = serialize_task.await {
        //             if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
        //                 telemetry::event!(
        //                     "Agent Thread Auto-Captured",
        //                     thread_id = thread_id.to_string(),
        //                     thread_data = thread_data,
        //                     auto_capture_reason = "tracked_user",
        //                     github_login = github_login
        //                 );

        //                 client.telemetry().flush_events().await;
        //             }
        //         }
        //     })
        //     .detach();
    }
1233
    /// Total token usage accumulated over this thread's completion requests.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
1237
    /// Token usage of the conversation up to (but not including) `message_id`.
    ///
    /// Currently unimplemented (`todo!`). The commented-out code below is the
    /// previous implementation, kept for reference: it looked up the message's
    /// index and returned the request usage recorded for the preceding
    /// message, together with the configured model's maximum token count.
    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
        todo!()
        // let Some(model) = self.configured_model.as_ref() else {
        //     return TotalTokenUsage::default();
        // };

        // let max = model.model.max_token_count();

        // let index = self
        //     .messages
        //     .iter()
        //     .position(|msg| msg.id == message_id)
        //     .unwrap_or(0);

        // if index == 0 {
        //     return TotalTokenUsage { total: 0, max };
        // }

        // let token_usage = &self
        //     .request_token_usage
        //     .get(index - 1)
        //     .cloned()
        //     .unwrap_or_default();

        // TotalTokenUsage {
        //     total: token_usage.total_tokens(),
        //     max,
        // }
    }
1267
    /// Token usage for the whole thread against the configured model's limit,
    /// or `None` when no model is configured.
    ///
    /// Currently unimplemented (`todo!`). The commented-out code below is the
    /// previous implementation, kept for reference: it preferred the recorded
    /// context-window-exceeded count when it matched the current model, and
    /// otherwise used the usage recorded at the last message.
    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
        todo!()
        // let model = self.configured_model.as_ref()?;

        // let max = model.model.max_token_count();

        // if let Some(exceeded_error) = &self.exceeded_window_error {
        //     if model.model.id() == exceeded_error.model_id {
        //         return Some(TotalTokenUsage {
        //             total: exceeded_error.token_count,
        //             max,
        //         });
        //     }
        // }

        // let total = self
        //     .token_usage_at_last_message()
        //     .unwrap_or_default()
        //     .total_tokens();

        // Some(TotalTokenUsage { total, max })
    }
1290
1291 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
1292 self.request_token_usage
1293 .get(self.messages.len().saturating_sub(1))
1294 .or_else(|| self.request_token_usage.last())
1295 .cloned()
1296 }
1297
1298 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
1299 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
1300 self.request_token_usage
1301 .resize(self.messages.len(), placeholder);
1302
1303 if let Some(last) = self.request_token_usage.last_mut() {
1304 *last = token_usage;
1305 }
1306 }
1307
1308 fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
1309 self.project.update(cx, |project, cx| {
1310 project.user_store().update(cx, |user_store, cx| {
1311 user_store.update_model_request_usage(
1312 ModelRequestUsage(RequestUsage {
1313 amount: amount as i32,
1314 limit,
1315 }),
1316 cx,
1317 )
1318 })
1319 });
1320 }
1321}
1322
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The request was rejected because payment is required.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has run out of model requests.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a longer message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1335
/// Events emitted by a [`Thread`] for the UI and other observers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be displayed to the user.
    ShowError(ThreadError),
    /// A chunk of a completion was streamed in.
    StreamedCompletion,
    /// A new completion request was started.
    NewRequest,
    /// The current completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// Messages in `old_range` were replaced; `new_length` messages took their place.
    MessagesUpdated {
        old_range: Range<usize>,
        new_length: usize,
    },
    /// A thread summary was generated by the model.
    SummaryGenerated,
    /// The thread summary changed (e.g. set manually).
    SummaryChanged,
    /// A git checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool call is waiting for user confirmation before running.
    ToolConfirmationNeeded,
    /// The configured tool-use limit was hit.
    ToolUseLimitReached,
    /// Editing was canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
    /// The active agent profile changed.
    ProfileChanged,
    /// All retry attempts failed; `message` describes the failure.
    RetriesFailed {
        message: SharedString,
    },
}
1358
1359impl EventEmitter<ThreadEvent> for Thread {}
1360
1361/// Resolves tool name conflicts by ensuring all tool names are unique.
1362///
1363/// When multiple tools have the same name, this function applies the following rules:
1364/// 1. Native tools always keep their original name
1365/// 2. Context server tools get prefixed with their server ID and an underscore
1366/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
1367/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
1368///
1369/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
1370fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
1371 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
1372 let mut tool_name = tool.name();
1373 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
1374 tool_name
1375 }
1376
1377 const MAX_TOOL_NAME_LENGTH: usize = 64;
1378
1379 let mut duplicated_tool_names = HashSet::default();
1380 let mut seen_tool_names = HashSet::default();
1381 for tool in tools {
1382 let tool_name = resolve_tool_name(tool);
1383 if seen_tool_names.contains(&tool_name) {
1384 debug_assert!(
1385 tool.source() != assistant_tool::ToolSource::Native,
1386 "There are two built-in tools with the same name: {}",
1387 tool_name
1388 );
1389 duplicated_tool_names.insert(tool_name);
1390 } else {
1391 seen_tool_names.insert(tool_name);
1392 }
1393 }
1394
1395 if duplicated_tool_names.is_empty() {
1396 return tools
1397 .into_iter()
1398 .map(|tool| (resolve_tool_name(tool), tool.clone()))
1399 .collect();
1400 }
1401
1402 tools
1403 .into_iter()
1404 .filter_map(|tool| {
1405 let mut tool_name = resolve_tool_name(tool);
1406 if !duplicated_tool_names.contains(&tool_name) {
1407 return Some((tool_name, tool.clone()));
1408 }
1409 match tool.source() {
1410 assistant_tool::ToolSource::Native => {
1411 // Built-in tools always keep their original name
1412 Some((tool_name, tool.clone()))
1413 }
1414 assistant_tool::ToolSource::ContextServer { id } => {
1415 // Context server tools are prefixed with the context server ID, and truncated if necessary
1416 tool_name.insert(0, '_');
1417 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
1418 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
1419 let mut id = id.to_string();
1420 id.truncate(len);
1421 tool_name.insert_str(0, &id);
1422 } else {
1423 tool_name.insert_str(0, &id);
1424 }
1425
1426 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
1427
1428 if seen_tool_names.contains(&tool_name) {
1429 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
1430 None
1431 } else {
1432 Some((tool_name, tool.clone()))
1433 }
1434 }
1435 }
1436 })
1437 .collect()
1438}
1439
1440// #[cfg(test)]
1441// mod tests {
1442// use super::*;
1443// use crate::{
1444// context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
1445// };
1446
1447// // Test-specific constants
1448// const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
1449// use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
1450// use assistant_tool::ToolRegistry;
1451// use futures::StreamExt;
1452// use futures::future::BoxFuture;
1453// use futures::stream::BoxStream;
1454// use gpui::TestAppContext;
1455// use icons::IconName;
1456// use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1457// use language_model::{
1458// LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
1459// LanguageModelProviderName, LanguageModelToolChoice,
1460// };
1461// use parking_lot::Mutex;
1462// use project::{FakeFs, Project};
1463// use prompt_store::PromptBuilder;
1464// use serde_json::json;
1465// use settings::{Settings, SettingsStore};
1466// use std::sync::Arc;
1467// use std::time::Duration;
1468// use theme::ThemeSettings;
1469// use util::path;
1470// use workspace::Workspace;
1471
1472// #[gpui::test]
1473// async fn test_message_with_context(cx: &mut TestAppContext) {
1474// init_test_settings(cx);
1475
1476// let project = create_test_project(
1477// cx,
1478// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1479// )
1480// .await;
1481
1482// let (_workspace, _thread_store, thread, context_store, model) =
1483// setup_test_environment(cx, project.clone()).await;
1484
1485// add_file_to_context(&project, &context_store, "test/code.rs", cx)
1486// .await
1487// .unwrap();
1488
1489// let context =
1490// context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
1491// let loaded_context = cx
1492// .update(|cx| load_context(vec![context], &project, &None, cx))
1493// .await;
1494
1495// // Insert user message with context
1496// let message_id = thread.update(cx, |thread, cx| {
1497// thread.insert_user_message(
1498// "Please explain this code",
1499// loaded_context,
1500// None,
1501// Vec::new(),
1502// cx,
1503// )
1504// });
1505
1506// // Check content and context in message object
1507// let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
1508
1509// // Use different path format strings based on platform for the test
1510// #[cfg(windows)]
1511// let path_part = r"test\code.rs";
1512// #[cfg(not(windows))]
1513// let path_part = "test/code.rs";
1514
1515// let expected_context = format!(
1516// r#"
1517// <context>
1518// The following items were attached by the user. They are up-to-date and don't need to be re-read.
1519
1520// <files>
1521// ```rs {path_part}
1522// fn main() {{
1523// println!("Hello, world!");
1524// }}
1525// ```
1526// </files>
1527// </context>
1528// "#
1529// );
1530
1531// assert_eq!(message.role, Role::User);
1532// assert_eq!(message.segments.len(), 1);
1533// assert_eq!(
1534// message.segments[0],
1535// MessageSegment::Text("Please explain this code".to_string())
1536// );
1537// assert_eq!(message.loaded_context.text, expected_context);
1538
1539// // Check message in request
1540// let request = thread.update(cx, |thread, cx| {
1541// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1542// });
1543
1544// assert_eq!(request.messages.len(), 2);
1545// let expected_full_message = format!("{}Please explain this code", expected_context);
1546// assert_eq!(request.messages[1].string_contents(), expected_full_message);
1547// }
1548
1549// #[gpui::test]
1550// async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
1551// init_test_settings(cx);
1552
1553// let project = create_test_project(
1554// cx,
1555// json!({
1556// "file1.rs": "fn function1() {}\n",
1557// "file2.rs": "fn function2() {}\n",
1558// "file3.rs": "fn function3() {}\n",
1559// "file4.rs": "fn function4() {}\n",
1560// }),
1561// )
1562// .await;
1563
1564// let (_, _thread_store, thread, context_store, model) =
1565// setup_test_environment(cx, project.clone()).await;
1566
1567// // First message with context 1
1568// add_file_to_context(&project, &context_store, "test/file1.rs", cx)
1569// .await
1570// .unwrap();
1571// let new_contexts = context_store.update(cx, |store, cx| {
1572// store.new_context_for_thread(thread.read(cx), None)
1573// });
1574// assert_eq!(new_contexts.len(), 1);
1575// let loaded_context = cx
1576// .update(|cx| load_context(new_contexts, &project, &None, cx))
1577// .await;
1578// let message1_id = thread.update(cx, |thread, cx| {
1579// thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
1580// });
1581
1582// // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
1583// add_file_to_context(&project, &context_store, "test/file2.rs", cx)
1584// .await
1585// .unwrap();
1586// let new_contexts = context_store.update(cx, |store, cx| {
1587// store.new_context_for_thread(thread.read(cx), None)
1588// });
1589// assert_eq!(new_contexts.len(), 1);
1590// let loaded_context = cx
1591// .update(|cx| load_context(new_contexts, &project, &None, cx))
1592// .await;
1593// let message2_id = thread.update(cx, |thread, cx| {
1594// thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
1595// });
1596
1597// // Third message with all three contexts (contexts 1 and 2 should be skipped)
1598// //
1599// add_file_to_context(&project, &context_store, "test/file3.rs", cx)
1600// .await
1601// .unwrap();
1602// let new_contexts = context_store.update(cx, |store, cx| {
1603// store.new_context_for_thread(thread.read(cx), None)
1604// });
1605// assert_eq!(new_contexts.len(), 1);
1606// let loaded_context = cx
1607// .update(|cx| load_context(new_contexts, &project, &None, cx))
1608// .await;
1609// let message3_id = thread.update(cx, |thread, cx| {
1610// thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
1611// });
1612
1613// // Check what contexts are included in each message
1614// let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
1615// (
1616// thread.message(message1_id).unwrap().clone(),
1617// thread.message(message2_id).unwrap().clone(),
1618// thread.message(message3_id).unwrap().clone(),
1619// )
1620// });
1621
1622// // First message should include context 1
1623// assert!(message1.loaded_context.text.contains("file1.rs"));
1624
1625// // Second message should include only context 2 (not 1)
1626// assert!(!message2.loaded_context.text.contains("file1.rs"));
1627// assert!(message2.loaded_context.text.contains("file2.rs"));
1628
1629// // Third message should include only context 3 (not 1 or 2)
1630// assert!(!message3.loaded_context.text.contains("file1.rs"));
1631// assert!(!message3.loaded_context.text.contains("file2.rs"));
1632// assert!(message3.loaded_context.text.contains("file3.rs"));
1633
1634// // Check entire request to make sure all contexts are properly included
1635// let request = thread.update(cx, |thread, cx| {
1636// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1637// });
1638
1639// // The request should contain all 3 messages
1640// assert_eq!(request.messages.len(), 4);
1641
1642// // Check that the contexts are properly formatted in each message
1643// assert!(request.messages[1].string_contents().contains("file1.rs"));
1644// assert!(!request.messages[1].string_contents().contains("file2.rs"));
1645// assert!(!request.messages[1].string_contents().contains("file3.rs"));
1646
1647// assert!(!request.messages[2].string_contents().contains("file1.rs"));
1648// assert!(request.messages[2].string_contents().contains("file2.rs"));
1649// assert!(!request.messages[2].string_contents().contains("file3.rs"));
1650
1651// assert!(!request.messages[3].string_contents().contains("file1.rs"));
1652// assert!(!request.messages[3].string_contents().contains("file2.rs"));
1653// assert!(request.messages[3].string_contents().contains("file3.rs"));
1654
1655// add_file_to_context(&project, &context_store, "test/file4.rs", cx)
1656// .await
1657// .unwrap();
1658// let new_contexts = context_store.update(cx, |store, cx| {
1659// store.new_context_for_thread(thread.read(cx), Some(message2_id))
1660// });
1661// assert_eq!(new_contexts.len(), 3);
1662// let loaded_context = cx
1663// .update(|cx| load_context(new_contexts, &project, &None, cx))
1664// .await
1665// .loaded_context;
1666
1667// assert!(!loaded_context.text.contains("file1.rs"));
1668// assert!(loaded_context.text.contains("file2.rs"));
1669// assert!(loaded_context.text.contains("file3.rs"));
1670// assert!(loaded_context.text.contains("file4.rs"));
1671
1672// let new_contexts = context_store.update(cx, |store, cx| {
1673// // Remove file4.rs
1674// store.remove_context(&loaded_context.contexts[2].handle(), cx);
1675// store.new_context_for_thread(thread.read(cx), Some(message2_id))
1676// });
1677// assert_eq!(new_contexts.len(), 2);
1678// let loaded_context = cx
1679// .update(|cx| load_context(new_contexts, &project, &None, cx))
1680// .await
1681// .loaded_context;
1682
1683// assert!(!loaded_context.text.contains("file1.rs"));
1684// assert!(loaded_context.text.contains("file2.rs"));
1685// assert!(loaded_context.text.contains("file3.rs"));
1686// assert!(!loaded_context.text.contains("file4.rs"));
1687
1688// let new_contexts = context_store.update(cx, |store, cx| {
1689// // Remove file3.rs
1690// store.remove_context(&loaded_context.contexts[1].handle(), cx);
1691// store.new_context_for_thread(thread.read(cx), Some(message2_id))
1692// });
1693// assert_eq!(new_contexts.len(), 1);
1694// let loaded_context = cx
1695// .update(|cx| load_context(new_contexts, &project, &None, cx))
1696// .await
1697// .loaded_context;
1698
1699// assert!(!loaded_context.text.contains("file1.rs"));
1700// assert!(loaded_context.text.contains("file2.rs"));
1701// assert!(!loaded_context.text.contains("file3.rs"));
1702// assert!(!loaded_context.text.contains("file4.rs"));
1703// }
1704
1705// #[gpui::test]
1706// async fn test_message_without_files(cx: &mut TestAppContext) {
1707// init_test_settings(cx);
1708
1709// let project = create_test_project(
1710// cx,
1711// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1712// )
1713// .await;
1714
1715// let (_, _thread_store, thread, _context_store, model) =
1716// setup_test_environment(cx, project.clone()).await;
1717
1718// // Insert user message without any context (empty context vector)
1719// let message_id = thread.update(cx, |thread, cx| {
1720// thread.insert_user_message(
1721// "What is the best way to learn Rust?",
1722// ContextLoadResult::default(),
1723// None,
1724// Vec::new(),
1725// cx,
1726// )
1727// });
1728
1729// // Check content and context in message object
1730// let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
1731
1732// // Context should be empty when no files are included
1733// assert_eq!(message.role, Role::User);
1734// assert_eq!(message.segments.len(), 1);
1735// assert_eq!(
1736// message.segments[0],
1737// MessageSegment::Text("What is the best way to learn Rust?".to_string())
1738// );
1739// assert_eq!(message.loaded_context.text, "");
1740
1741// // Check message in request
1742// let request = thread.update(cx, |thread, cx| {
1743// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1744// });
1745
1746// assert_eq!(request.messages.len(), 2);
1747// assert_eq!(
1748// request.messages[1].string_contents(),
1749// "What is the best way to learn Rust?"
1750// );
1751
1752// // Add second message, also without context
1753// let message2_id = thread.update(cx, |thread, cx| {
1754// thread.insert_user_message(
1755// "Are there any good books?",
1756// ContextLoadResult::default(),
1757// None,
1758// Vec::new(),
1759// cx,
1760// )
1761// });
1762
1763// let message2 =
1764// thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
1765// assert_eq!(message2.loaded_context.text, "");
1766
1767// // Check that both messages appear in the request
1768// let request = thread.update(cx, |thread, cx| {
1769// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1770// });
1771
1772// assert_eq!(request.messages.len(), 3);
1773// assert_eq!(
1774// request.messages[1].string_contents(),
1775// "What is the best way to learn Rust?"
1776// );
1777// assert_eq!(
1778// request.messages[2].string_contents(),
1779// "Are there any good books?"
1780// );
1781// }
1782
1783// #[gpui::test]
1784// async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
1785// init_test_settings(cx);
1786
1787// let project = create_test_project(
1788// cx,
1789// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1790// )
1791// .await;
1792
1793// let (_workspace, thread_store, thread, _context_store, _model) =
1794// setup_test_environment(cx, project.clone()).await;
1795
1796// // Check that we are starting with the default profile
1797// let profile = cx.read(|cx| thread.read(cx).profile.clone());
1798// let tool_set = cx.read(|cx| thread_store.read(cx).tools());
1799// assert_eq!(
1800// profile,
1801// AgentProfile::new(AgentProfileId::default(), tool_set)
1802// );
1803// }
1804
1805// #[gpui::test]
1806// async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
1807// init_test_settings(cx);
1808
1809// let project = create_test_project(
1810// cx,
1811// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1812// )
1813// .await;
1814
1815// let (_workspace, thread_store, thread, _context_store, _model) =
1816// setup_test_environment(cx, project.clone()).await;
1817
1818// // Profile gets serialized with default values
1819// let serialized = thread
1820// .update(cx, |thread, cx| thread.serialize(cx))
1821// .await
1822// .unwrap();
1823
1824// assert_eq!(serialized.profile, Some(AgentProfileId::default()));
1825
1826// let deserialized = cx.update(|cx| {
1827// thread.update(cx, |thread, cx| {
1828// Thread::deserialize(
1829// thread.id.clone(),
1830// serialized,
1831// thread.project.clone(),
1832// thread.tools.clone(),
1833// thread.prompt_builder.clone(),
1834// thread.project_context.clone(),
1835// None,
1836// cx,
1837// )
1838// })
1839// });
1840// let tool_set = cx.read(|cx| thread_store.read(cx).tools());
1841
1842// assert_eq!(
1843// deserialized.profile,
1844// AgentProfile::new(AgentProfileId::default(), tool_set)
1845// );
1846// }
1847
1848// #[gpui::test]
1849// async fn test_temperature_setting(cx: &mut TestAppContext) {
1850// init_test_settings(cx);
1851
1852// let project = create_test_project(
1853// cx,
1854// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1855// )
1856// .await;
1857
1858// let (_workspace, _thread_store, thread, _context_store, model) =
1859// setup_test_environment(cx, project.clone()).await;
1860
1861// // Both model and provider
1862// cx.update(|cx| {
1863// AgentSettings::override_global(
1864// AgentSettings {
1865// model_parameters: vec![LanguageModelParameters {
1866// provider: Some(model.provider_id().0.to_string().into()),
1867// model: Some(model.id().0.clone()),
1868// temperature: Some(0.66),
1869// }],
1870// ..AgentSettings::get_global(cx).clone()
1871// },
1872// cx,
1873// );
1874// });
1875
1876// let request = thread.update(cx, |thread, cx| {
1877// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1878// });
1879// assert_eq!(request.temperature, Some(0.66));
1880
1881// // Only model
1882// cx.update(|cx| {
1883// AgentSettings::override_global(
1884// AgentSettings {
1885// model_parameters: vec![LanguageModelParameters {
1886// provider: None,
1887// model: Some(model.id().0.clone()),
1888// temperature: Some(0.66),
1889// }],
1890// ..AgentSettings::get_global(cx).clone()
1891// },
1892// cx,
1893// );
1894// });
1895
1896// let request = thread.update(cx, |thread, cx| {
1897// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1898// });
1899// assert_eq!(request.temperature, Some(0.66));
1900
1901// // Only provider
1902// cx.update(|cx| {
1903// AgentSettings::override_global(
1904// AgentSettings {
1905// model_parameters: vec![LanguageModelParameters {
1906// provider: Some(model.provider_id().0.to_string().into()),
1907// model: None,
1908// temperature: Some(0.66),
1909// }],
1910// ..AgentSettings::get_global(cx).clone()
1911// },
1912// cx,
1913// );
1914// });
1915
1916// let request = thread.update(cx, |thread, cx| {
1917// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1918// });
1919// assert_eq!(request.temperature, Some(0.66));
1920
1921// // Same model name, different provider
1922// cx.update(|cx| {
1923// AgentSettings::override_global(
1924// AgentSettings {
1925// model_parameters: vec![LanguageModelParameters {
1926// provider: Some("anthropic".into()),
1927// model: Some(model.id().0.clone()),
1928// temperature: Some(0.66),
1929// }],
1930// ..AgentSettings::get_global(cx).clone()
1931// },
1932// cx,
1933// );
1934// });
1935
1936// let request = thread.update(cx, |thread, cx| {
1937// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1938// });
1939// assert_eq!(request.temperature, None);
1940// }
1941
1942// #[gpui::test]
1943// async fn test_thread_summary(cx: &mut TestAppContext) {
1944// init_test_settings(cx);
1945
1946// let project = create_test_project(cx, json!({})).await;
1947
1948// let (_, _thread_store, thread, _context_store, model) =
1949// setup_test_environment(cx, project.clone()).await;
1950
1951// // Initial state should be pending
1952// thread.read_with(cx, |thread, _| {
1953// assert!(matches!(thread.summary(), ThreadSummary::Pending));
1954// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
1955// });
1956
1957// // Manually setting the summary should not be allowed in this state
1958// thread.update(cx, |thread, cx| {
1959// thread.set_summary("This should not work", cx);
1960// });
1961
1962// thread.read_with(cx, |thread, _| {
1963// assert!(matches!(thread.summary(), ThreadSummary::Pending));
1964// });
1965
1966// // Send a message
1967// thread.update(cx, |thread, cx| {
1968// thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
1969// thread.send_to_model(
1970// model.clone(),
1971// CompletionIntent::ThreadSummarization,
1972// None,
1973// cx,
1974// );
1975// });
1976
1977// let fake_model = model.as_fake();
1978// simulate_successful_response(&fake_model, cx);
1979
1980// // Should start generating summary when there are >= 2 messages
1981// thread.read_with(cx, |thread, _| {
1982// assert_eq!(*thread.summary(), ThreadSummary::Generating);
1983// });
1984
1985// // Should not be able to set the summary while generating
1986// thread.update(cx, |thread, cx| {
1987// thread.set_summary("This should not work either", cx);
1988// });
1989
1990// thread.read_with(cx, |thread, _| {
1991// assert!(matches!(thread.summary(), ThreadSummary::Generating));
1992// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
1993// });
1994
1995// cx.run_until_parked();
1996// fake_model.stream_last_completion_response("Brief");
1997// fake_model.stream_last_completion_response(" Introduction");
1998// fake_model.end_last_completion_stream();
1999// cx.run_until_parked();
2000
2001// // Summary should be set
2002// thread.read_with(cx, |thread, _| {
2003// assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
2004// assert_eq!(thread.summary().or_default(), "Brief Introduction");
2005// });
2006
2007// // Now we should be able to set a summary
2008// thread.update(cx, |thread, cx| {
2009// thread.set_summary("Brief Intro", cx);
2010// });
2011
2012// thread.read_with(cx, |thread, _| {
2013// assert_eq!(thread.summary().or_default(), "Brief Intro");
2014// });
2015
2016// // Test setting an empty summary (should default to DEFAULT)
2017// thread.update(cx, |thread, cx| {
2018// thread.set_summary("", cx);
2019// });
2020
2021// thread.read_with(cx, |thread, _| {
2022// assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
2023// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
2024// });
2025// }
2026
2027// #[gpui::test]
2028// async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
2029// init_test_settings(cx);
2030
2031// let project = create_test_project(cx, json!({})).await;
2032
2033// let (_, _thread_store, thread, _context_store, model) =
2034// setup_test_environment(cx, project.clone()).await;
2035
2036// test_summarize_error(&model, &thread, cx);
2037
2038// // Now we should be able to set a summary
2039// thread.update(cx, |thread, cx| {
2040// thread.set_summary("Brief Intro", cx);
2041// });
2042
2043// thread.read_with(cx, |thread, _| {
2044// assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
2045// assert_eq!(thread.summary().or_default(), "Brief Intro");
2046// });
2047// }
2048
2049// #[gpui::test]
2050// async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
2051// init_test_settings(cx);
2052
2053// let project = create_test_project(cx, json!({})).await;
2054
2055// let (_, _thread_store, thread, _context_store, model) =
2056// setup_test_environment(cx, project.clone()).await;
2057
2058// test_summarize_error(&model, &thread, cx);
2059
2060// // Sending another message should not trigger another summarize request
2061// thread.update(cx, |thread, cx| {
2062// thread.insert_user_message(
2063// "How are you?",
2064// ContextLoadResult::default(),
2065// None,
2066// vec![],
2067// cx,
2068// );
2069// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2070// });
2071
2072// let fake_model = model.as_fake();
2073// simulate_successful_response(&fake_model, cx);
2074
2075// thread.read_with(cx, |thread, _| {
2076// // State is still Error, not Generating
2077// assert!(matches!(thread.summary(), ThreadSummary::Error));
2078// });
2079
2080// // But the summarize request can be invoked manually
2081// thread.update(cx, |thread, cx| {
2082// thread.summarize(cx);
2083// });
2084
2085// thread.read_with(cx, |thread, _| {
2086// assert!(matches!(thread.summary(), ThreadSummary::Generating));
2087// });
2088
2089// cx.run_until_parked();
2090// fake_model.stream_last_completion_response("A successful summary");
2091// fake_model.end_last_completion_stream();
2092// cx.run_until_parked();
2093
2094// thread.read_with(cx, |thread, _| {
2095// assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
2096// assert_eq!(thread.summary().or_default(), "A successful summary");
2097// });
2098// }
2099
2100// #[gpui::test]
2101// fn test_resolve_tool_name_conflicts() {
2102// use assistant_tool::{Tool, ToolSource};
2103
2104// assert_resolve_tool_name_conflicts(
2105// vec![
2106// TestTool::new("tool1", ToolSource::Native),
2107// TestTool::new("tool2", ToolSource::Native),
2108// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
2109// ],
2110// vec!["tool1", "tool2", "tool3"],
2111// );
2112
2113// assert_resolve_tool_name_conflicts(
2114// vec![
2115// TestTool::new("tool1", ToolSource::Native),
2116// TestTool::new("tool2", ToolSource::Native),
2117// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
2118// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
2119// ],
2120// vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
2121// );
2122
2123// assert_resolve_tool_name_conflicts(
2124// vec![
2125// TestTool::new("tool1", ToolSource::Native),
2126// TestTool::new("tool2", ToolSource::Native),
2127// TestTool::new("tool3", ToolSource::Native),
2128// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
2129// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
2130// ],
2131// vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
2132// );
2133
2134// // Test that a tool with a very long name is always truncated
2135// assert_resolve_tool_name_conflicts(
2136// vec![TestTool::new(
2137// "tool-with-more-than-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
2138// ToolSource::Native,
2139// )],
2140// vec!["tool-with-more-than-64-characters-blah-blah-blah-blah-blah-blah-"],
2141// );
2142
2143// // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
2144// assert_resolve_tool_name_conflicts(
2145// vec![
2146// TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
2147// TestTool::new(
2148// "tool-with-very-very-very-long-name",
2149// ToolSource::ContextServer {
2150// id: "mcp-with-very-very-very-long-name".into(),
2151// },
2152// ),
2153// ],
2154// vec![
2155// "tool-with-very-very-very-long-name",
2156// "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
2157// ],
2158// );
2159
2160// fn assert_resolve_tool_name_conflicts(
2161// tools: Vec<TestTool>,
2162// expected: Vec<impl Into<String>>,
2163// ) {
2164// let tools: Vec<Arc<dyn Tool>> = tools
2165// .into_iter()
2166// .map(|t| Arc::new(t) as Arc<dyn Tool>)
2167// .collect();
2168// let tools = resolve_tool_name_conflicts(&tools);
2169// assert_eq!(tools.len(), expected.len());
2170// for (i, expected_name) in expected.into_iter().enumerate() {
2171// let expected_name = expected_name.into();
2172// let actual_name = &tools[i].0;
2173// assert_eq!(
2174// actual_name, &expected_name,
2175// "Expected '{}' got '{}' at index {}",
2176// expected_name, actual_name, i
2177// );
2178// }
2179// }
2180
2181// struct TestTool {
2182// name: String,
2183// source: ToolSource,
2184// }
2185
2186// impl TestTool {
2187// fn new(name: impl Into<String>, source: ToolSource) -> Self {
2188// Self {
2189// name: name.into(),
2190// source,
2191// }
2192// }
2193// }
2194
2195// impl Tool for TestTool {
2196// fn name(&self) -> String {
2197// self.name.clone()
2198// }
2199
2200// fn icon(&self) -> IconName {
2201// IconName::Ai
2202// }
2203
2204// fn may_perform_edits(&self) -> bool {
2205// false
2206// }
2207
2208// fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
2209// true
2210// }
2211
2212// fn source(&self) -> ToolSource {
2213// self.source.clone()
2214// }
2215
2216// fn description(&self) -> String {
2217// "Test tool".to_string()
2218// }
2219
2220// fn ui_text(&self, _input: &serde_json::Value) -> String {
2221// "Test tool".to_string()
2222// }
2223
2224// fn run(
2225// self: Arc<Self>,
2226// _input: serde_json::Value,
2227// _request: Arc<LanguageModelRequest>,
2228// _project: Entity<Project>,
2229// _action_log: Entity<ActionLog>,
2230// _model: Arc<dyn LanguageModel>,
2231// _window: Option<AnyWindowHandle>,
2232// _cx: &mut App,
2233// ) -> assistant_tool::ToolResult {
2234// assistant_tool::ToolResult {
2235// output: Task::ready(Err(anyhow::anyhow!("No content"))),
2236// card: None,
2237// }
2238// }
2239// }
2240// }
2241
2242// // Helper to create a model that returns errors
2243// enum TestError {
2244// Overloaded,
2245// InternalServerError,
2246// }
2247
2248// struct ErrorInjector {
2249// inner: Arc<FakeLanguageModel>,
2250// error_type: TestError,
2251// }
2252
2253// impl ErrorInjector {
2254// fn new(error_type: TestError) -> Self {
2255// Self {
2256// inner: Arc::new(FakeLanguageModel::default()),
2257// error_type,
2258// }
2259// }
2260// }
2261
2262// impl LanguageModel for ErrorInjector {
2263// fn id(&self) -> LanguageModelId {
2264// self.inner.id()
2265// }
2266
2267// fn name(&self) -> LanguageModelName {
2268// self.inner.name()
2269// }
2270
2271// fn provider_id(&self) -> LanguageModelProviderId {
2272// self.inner.provider_id()
2273// }
2274
2275// fn provider_name(&self) -> LanguageModelProviderName {
2276// self.inner.provider_name()
2277// }
2278
2279// fn supports_tools(&self) -> bool {
2280// self.inner.supports_tools()
2281// }
2282
2283// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
2284// self.inner.supports_tool_choice(choice)
2285// }
2286
2287// fn supports_images(&self) -> bool {
2288// self.inner.supports_images()
2289// }
2290
2291// fn telemetry_id(&self) -> String {
2292// self.inner.telemetry_id()
2293// }
2294
2295// fn max_token_count(&self) -> u64 {
2296// self.inner.max_token_count()
2297// }
2298
2299// fn count_tokens(
2300// &self,
2301// request: LanguageModelRequest,
2302// cx: &App,
2303// ) -> BoxFuture<'static, Result<u64>> {
2304// self.inner.count_tokens(request, cx)
2305// }
2306
2307// fn stream_completion(
2308// &self,
2309// _request: LanguageModelRequest,
2310// _cx: &AsyncApp,
2311// ) -> BoxFuture<
2312// 'static,
2313// Result<
2314// BoxStream<
2315// 'static,
2316// Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
2317// >,
2318// LanguageModelCompletionError,
2319// >,
2320// > {
2321// let error = match self.error_type {
2322// TestError::Overloaded => LanguageModelCompletionError::Overloaded,
2323// TestError::InternalServerError => {
2324// LanguageModelCompletionError::ApiInternalServerError
2325// }
2326// };
2327// async move {
2328// let stream = futures::stream::once(async move { Err(error) });
2329// Ok(stream.boxed())
2330// }
2331// .boxed()
2332// }
2333
2334// fn as_fake(&self) -> &FakeLanguageModel {
2335// &self.inner
2336// }
2337// }
2338
2339// #[gpui::test]
2340// async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
2341// init_test_settings(cx);
2342
2343// let project = create_test_project(cx, json!({})).await;
2344// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2345
2346// // Create model that returns overloaded error
2347// let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
2348
2349// // Insert a user message
2350// thread.update(cx, |thread, cx| {
2351// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2352// });
2353
2354// // Start completion
2355// thread.update(cx, |thread, cx| {
2356// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2357// });
2358
2359// cx.run_until_parked();
2360
2361// thread.read_with(cx, |thread, _| {
2362// assert!(thread.retry_state.is_some(), "Should have retry state");
2363// let retry_state = thread.retry_state.as_ref().unwrap();
2364// assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
2365// assert_eq!(
2366// retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
2367// "Should have default max attempts"
2368// );
2369// });
2370
2371// // Check that a retry message was added
2372// thread.read_with(cx, |thread, _| {
2373// let mut messages = thread.messages();
2374// assert!(
2375// messages.any(|msg| {
2376// msg.role == Role::System
2377// && msg.ui_only
2378// && msg.segments.iter().any(|seg| {
2379// if let MessageSegment::Text(text) = seg {
2380// text.contains("overloaded")
2381// && text
2382// .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
2383// } else {
2384// false
2385// }
2386// })
2387// }),
2388// "Should have added a system retry message"
2389// );
2390// });
2391
2392// let retry_count = thread.update(cx, |thread, _| {
2393// thread
2394// .messages
2395// .iter()
2396// .filter(|m| {
2397// m.ui_only
2398// && m.segments.iter().any(|s| {
2399// if let MessageSegment::Text(text) = s {
2400// text.contains("Retrying") && text.contains("seconds")
2401// } else {
2402// false
2403// }
2404// })
2405// })
2406// .count()
2407// });
2408
2409// assert_eq!(retry_count, 1, "Should have one retry message");
2410// }
2411
2412// #[gpui::test]
2413// async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
2414// init_test_settings(cx);
2415
2416// let project = create_test_project(cx, json!({})).await;
2417// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2418
2419// // Create model that returns internal server error
2420// let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
2421
2422// // Insert a user message
2423// thread.update(cx, |thread, cx| {
2424// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2425// });
2426
2427// // Start completion
2428// thread.update(cx, |thread, cx| {
2429// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2430// });
2431
2432// cx.run_until_parked();
2433
2434// // Check retry state on thread
2435// thread.read_with(cx, |thread, _| {
2436// assert!(thread.retry_state.is_some(), "Should have retry state");
2437// let retry_state = thread.retry_state.as_ref().unwrap();
2438// assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
2439// assert_eq!(
2440// retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
2441// "Should have correct max attempts"
2442// );
2443// });
2444
2445// // Check that a retry message was added with provider name
2446// thread.read_with(cx, |thread, _| {
2447// let mut messages = thread.messages();
2448// assert!(
2449// messages.any(|msg| {
2450// msg.role == Role::System
2451// && msg.ui_only
2452// && msg.segments.iter().any(|seg| {
2453// if let MessageSegment::Text(text) = seg {
2454// text.contains("internal")
2455// && text.contains("Fake")
2456// && text
2457// .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
2458// } else {
2459// false
2460// }
2461// })
2462// }),
2463// "Should have added a system retry message with provider name"
2464// );
2465// });
2466
2467// // Count retry messages
2468// let retry_count = thread.update(cx, |thread, _| {
2469// thread
2470// .messages
2471// .iter()
2472// .filter(|m| {
2473// m.ui_only
2474// && m.segments.iter().any(|s| {
2475// if let MessageSegment::Text(text) = s {
2476// text.contains("Retrying") && text.contains("seconds")
2477// } else {
2478// false
2479// }
2480// })
2481// })
2482// .count()
2483// });
2484
2485// assert_eq!(retry_count, 1, "Should have one retry message");
2486// }
2487
2488// #[gpui::test]
2489// async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
2490// init_test_settings(cx);
2491
2492// let project = create_test_project(cx, json!({})).await;
2493// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2494
2495// // Create model that returns overloaded error
2496// let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
2497
2498// // Insert a user message
2499// thread.update(cx, |thread, cx| {
2500// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2501// });
2502
2503// // Track retry events and completion count
2504// // Track completion events
2505// let completion_count = Arc::new(Mutex::new(0));
2506// let completion_count_clone = completion_count.clone();
2507
2508// let _subscription = thread.update(cx, |_, cx| {
2509// cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
2510// if let ThreadEvent::NewRequest = event {
2511// *completion_count_clone.lock() += 1;
2512// }
2513// })
2514// });
2515
2516// // First attempt
2517// thread.update(cx, |thread, cx| {
2518// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2519// });
2520// cx.run_until_parked();
2521
2522// // Should have scheduled first retry - count retry messages
2523// let retry_count = thread.update(cx, |thread, _| {
2524// thread
2525// .messages
2526// .iter()
2527// .filter(|m| {
2528// m.ui_only
2529// && m.segments.iter().any(|s| {
2530// if let MessageSegment::Text(text) = s {
2531// text.contains("Retrying") && text.contains("seconds")
2532// } else {
2533// false
2534// }
2535// })
2536// })
2537// .count()
2538// });
2539// assert_eq!(retry_count, 1, "Should have scheduled first retry");
2540
2541// // Check retry state
2542// thread.read_with(cx, |thread, _| {
2543// assert!(thread.retry_state.is_some(), "Should have retry state");
2544// let retry_state = thread.retry_state.as_ref().unwrap();
2545// assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
2546// });
2547
2548// // Advance clock for first retry
2549// cx.executor()
2550// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
2551// cx.run_until_parked();
2552
2553// // Should have scheduled second retry - count retry messages
2554// let retry_count = thread.update(cx, |thread, _| {
2555// thread
2556// .messages
2557// .iter()
2558// .filter(|m| {
2559// m.ui_only
2560// && m.segments.iter().any(|s| {
2561// if let MessageSegment::Text(text) = s {
2562// text.contains("Retrying") && text.contains("seconds")
2563// } else {
2564// false
2565// }
2566// })
2567// })
2568// .count()
2569// });
2570// assert_eq!(retry_count, 2, "Should have scheduled second retry");
2571
2572// // Check retry state updated
2573// thread.read_with(cx, |thread, _| {
2574// assert!(thread.retry_state.is_some(), "Should have retry state");
2575// let retry_state = thread.retry_state.as_ref().unwrap();
2576// assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
2577// assert_eq!(
2578// retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
2579// "Should have correct max attempts"
2580// );
2581// });
2582
2583// // Advance clock for second retry (exponential backoff)
2584// cx.executor()
2585// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
2586// cx.run_until_parked();
2587
2588// // Should have scheduled third retry
2589// // Count all retry messages now
2590// let retry_count = thread.update(cx, |thread, _| {
2591// thread
2592// .messages
2593// .iter()
2594// .filter(|m| {
2595// m.ui_only
2596// && m.segments.iter().any(|s| {
2597// if let MessageSegment::Text(text) = s {
2598// text.contains("Retrying") && text.contains("seconds")
2599// } else {
2600// false
2601// }
2602// })
2603// })
2604// .count()
2605// });
2606// assert_eq!(
2607// retry_count, MAX_RETRY_ATTEMPTS as usize,
2608// "Should have scheduled third retry"
2609// );
2610
2611// // Check retry state updated
2612// thread.read_with(cx, |thread, _| {
2613// assert!(thread.retry_state.is_some(), "Should have retry state");
2614// let retry_state = thread.retry_state.as_ref().unwrap();
2615// assert_eq!(
2616// retry_state.attempt, MAX_RETRY_ATTEMPTS,
2617// "Should be at max retry attempt"
2618// );
2619// assert_eq!(
2620// retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
2621// "Should have correct max attempts"
2622// );
2623// });
2624
2625// // Advance clock for third retry (exponential backoff)
2626// cx.executor()
2627// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
2628// cx.run_until_parked();
2629
2630// // No more retries should be scheduled after the clock was advanced.
2631// let retry_count = thread.update(cx, |thread, _| {
2632// thread
2633// .messages
2634// .iter()
2635// .filter(|m| {
2636// m.ui_only
2637// && m.segments.iter().any(|s| {
2638// if let MessageSegment::Text(text) = s {
2639// text.contains("Retrying") && text.contains("seconds")
2640// } else {
2641// false
2642// }
2643// })
2644// })
2645// .count()
2646// });
2647// assert_eq!(
2648// retry_count, MAX_RETRY_ATTEMPTS as usize,
2649// "Should not exceed max retries"
2650// );
2651
2652// // Final completion count should be initial + max retries
2653// assert_eq!(
2654// *completion_count.lock(),
2655// (MAX_RETRY_ATTEMPTS + 1) as usize,
2656// "Should have made initial + max retry attempts"
2657// );
2658// }
2659
2660// #[gpui::test]
2661// async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
2662// init_test_settings(cx);
2663
2664// let project = create_test_project(cx, json!({})).await;
2665// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2666
2667// // Create model that returns overloaded error
2668// let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
2669
2670// // Insert a user message
2671// thread.update(cx, |thread, cx| {
2672// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2673// });
2674
2675// // Track events
2676// let retries_failed = Arc::new(Mutex::new(false));
2677// let retries_failed_clone = retries_failed.clone();
2678
2679// let _subscription = thread.update(cx, |_, cx| {
2680// cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
2681// if let ThreadEvent::RetriesFailed { .. } = event {
2682// *retries_failed_clone.lock() = true;
2683// }
2684// })
2685// });
2686
2687// // Start initial completion
2688// thread.update(cx, |thread, cx| {
2689// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2690// });
2691// cx.run_until_parked();
2692
2693// // Advance through all retries
2694// for i in 0..MAX_RETRY_ATTEMPTS {
2695// let delay = if i == 0 {
2696// BASE_RETRY_DELAY_SECS
2697// } else {
2698// BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
2699// };
2700// cx.executor().advance_clock(Duration::from_secs(delay));
2701// cx.run_until_parked();
2702// }
2703
2704// // After the 3rd retry is scheduled, we need to wait for it to execute and fail
2705// // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
2706// let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
2707// cx.executor()
2708// .advance_clock(Duration::from_secs(final_delay));
2709// cx.run_until_parked();
2710
2711// let retry_count = thread.update(cx, |thread, _| {
2712// thread
2713// .messages
2714// .iter()
2715// .filter(|m| {
2716// m.ui_only
2717// && m.segments.iter().any(|s| {
2718// if let MessageSegment::Text(text) = s {
2719// text.contains("Retrying") && text.contains("seconds")
2720// } else {
2721// false
2722// }
2723// })
2724// })
2725// .count()
2726// });
2727
2728// // After max retries, should emit RetriesFailed event
2729// assert_eq!(
2730// retry_count, MAX_RETRY_ATTEMPTS as usize,
2731// "Should have attempted max retries"
2732// );
2733// assert!(
2734// *retries_failed.lock(),
2735// "Should emit RetriesFailed event after max retries exceeded"
2736// );
2737
2738// // Retry state should be cleared
2739// thread.read_with(cx, |thread, _| {
2740// assert!(
2741// thread.retry_state.is_none(),
2742// "Retry state should be cleared after max retries"
2743// );
2744
2745// // Verify we have the expected number of retry messages
2746// let retry_messages = thread
2747// .messages
2748// .iter()
2749// .filter(|msg| msg.ui_only && msg.role == Role::System)
2750// .count();
2751// assert_eq!(
2752// retry_messages, MAX_RETRY_ATTEMPTS as usize,
2753// "Should have one retry message per attempt"
2754// );
2755// });
2756// }
2757
2758// #[gpui::test]
2759// async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
2760// init_test_settings(cx);
2761
2762// let project = create_test_project(cx, json!({})).await;
2763// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2764
2765// // We'll use a wrapper to switch behavior after first failure
2766// struct RetryTestModel {
2767// inner: Arc<FakeLanguageModel>,
2768// failed_once: Arc<Mutex<bool>>,
2769// }
2770
2771// impl LanguageModel for RetryTestModel {
2772// fn id(&self) -> LanguageModelId {
2773// self.inner.id()
2774// }
2775
2776// fn name(&self) -> LanguageModelName {
2777// self.inner.name()
2778// }
2779
2780// fn provider_id(&self) -> LanguageModelProviderId {
2781// self.inner.provider_id()
2782// }
2783
2784// fn provider_name(&self) -> LanguageModelProviderName {
2785// self.inner.provider_name()
2786// }
2787
2788// fn supports_tools(&self) -> bool {
2789// self.inner.supports_tools()
2790// }
2791
2792// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
2793// self.inner.supports_tool_choice(choice)
2794// }
2795
2796// fn supports_images(&self) -> bool {
2797// self.inner.supports_images()
2798// }
2799
2800// fn telemetry_id(&self) -> String {
2801// self.inner.telemetry_id()
2802// }
2803
2804// fn max_token_count(&self) -> u64 {
2805// self.inner.max_token_count()
2806// }
2807
2808// fn count_tokens(
2809// &self,
2810// request: LanguageModelRequest,
2811// cx: &App,
2812// ) -> BoxFuture<'static, Result<u64>> {
2813// self.inner.count_tokens(request, cx)
2814// }
2815
2816// fn stream_completion(
2817// &self,
2818// request: LanguageModelRequest,
2819// cx: &AsyncApp,
2820// ) -> BoxFuture<
2821// 'static,
2822// Result<
2823// BoxStream<
2824// 'static,
2825// Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
2826// >,
2827// LanguageModelCompletionError,
2828// >,
2829// > {
2830// if !*self.failed_once.lock() {
2831// *self.failed_once.lock() = true;
2832// // Return error on first attempt
2833// let stream = futures::stream::once(async move {
2834// Err(LanguageModelCompletionError::Overloaded)
2835// });
2836// async move { Ok(stream.boxed()) }.boxed()
2837// } else {
2838// // Succeed on retry
2839// self.inner.stream_completion(request, cx)
2840// }
2841// }
2842
2843// fn as_fake(&self) -> &FakeLanguageModel {
2844// &self.inner
2845// }
2846// }
2847
2848// let model = Arc::new(RetryTestModel {
2849// inner: Arc::new(FakeLanguageModel::default()),
2850// failed_once: Arc::new(Mutex::new(false)),
2851// });
2852
2853// // Insert a user message
2854// thread.update(cx, |thread, cx| {
2855// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2856// });
2857
2858// // Track message deletions
2859// // Track when retry completes successfully
2860// let retry_completed = Arc::new(Mutex::new(false));
2861// let retry_completed_clone = retry_completed.clone();
2862
2863// let _subscription = thread.update(cx, |_, cx| {
2864// cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
2865// if let ThreadEvent::StreamedCompletion = event {
2866// *retry_completed_clone.lock() = true;
2867// }
2868// })
2869// });
2870
2871// // Start completion
2872// thread.update(cx, |thread, cx| {
2873// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2874// });
2875// cx.run_until_parked();
2876
2877// // Get the retry message ID
2878// let retry_message_id = thread.read_with(cx, |thread, _| {
2879// thread
2880// .messages()
2881// .find(|msg| msg.role == Role::System && msg.ui_only)
2882// .map(|msg| msg.id)
2883// .expect("Should have a retry message")
2884// });
2885
2886// // Wait for retry
2887// cx.executor()
2888// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
2889// cx.run_until_parked();
2890
2891// // Stream some successful content
2892// let fake_model = model.as_fake();
2893// // After the retry, there should be a new pending completion
2894// let pending = fake_model.pending_completions();
2895// assert!(
2896// !pending.is_empty(),
2897// "Should have a pending completion after retry"
2898// );
2899// fake_model.stream_completion_response(&pending[0], "Success!");
2900// fake_model.end_completion_stream(&pending[0]);
2901// cx.run_until_parked();
2902
2903// // Check that the retry completed successfully
2904// assert!(
2905// *retry_completed.lock(),
2906// "Retry should have completed successfully"
2907// );
2908
2909// // Retry message should still exist but be marked as ui_only
2910// thread.read_with(cx, |thread, _| {
2911// let retry_msg = thread
2912// .message(retry_message_id)
2913// .expect("Retry message should still exist");
2914// assert!(retry_msg.ui_only, "Retry message should be ui_only");
2915// assert_eq!(
2916// retry_msg.role,
2917// Role::System,
2918// "Retry message should have System role"
2919// );
2920// });
2921// }
2922
2923// #[gpui::test]
2924// async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
2925// init_test_settings(cx);
2926
2927// let project = create_test_project(cx, json!({})).await;
2928// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2929
2930// // Create a model that fails once then succeeds
2931// struct FailOnceModel {
2932// inner: Arc<FakeLanguageModel>,
2933// failed_once: Arc<Mutex<bool>>,
2934// }
2935
2936// impl LanguageModel for FailOnceModel {
2937// fn id(&self) -> LanguageModelId {
2938// self.inner.id()
2939// }
2940
2941// fn name(&self) -> LanguageModelName {
2942// self.inner.name()
2943// }
2944
2945// fn provider_id(&self) -> LanguageModelProviderId {
2946// self.inner.provider_id()
2947// }
2948
2949// fn provider_name(&self) -> LanguageModelProviderName {
2950// self.inner.provider_name()
2951// }
2952
2953// fn supports_tools(&self) -> bool {
2954// self.inner.supports_tools()
2955// }
2956
2957// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
2958// self.inner.supports_tool_choice(choice)
2959// }
2960
2961// fn supports_images(&self) -> bool {
2962// self.inner.supports_images()
2963// }
2964
2965// fn telemetry_id(&self) -> String {
2966// self.inner.telemetry_id()
2967// }
2968
2969// fn max_token_count(&self) -> u64 {
2970// self.inner.max_token_count()
2971// }
2972
2973// fn count_tokens(
2974// &self,
2975// request: LanguageModelRequest,
2976// cx: &App,
2977// ) -> BoxFuture<'static, Result<u64>> {
2978// self.inner.count_tokens(request, cx)
2979// }
2980
2981// fn stream_completion(
2982// &self,
2983// request: LanguageModelRequest,
2984// cx: &AsyncApp,
2985// ) -> BoxFuture<
2986// 'static,
2987// Result<
2988// BoxStream<
2989// 'static,
2990// Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
2991// >,
2992// LanguageModelCompletionError,
2993// >,
2994// > {
2995// if !*self.failed_once.lock() {
2996// *self.failed_once.lock() = true;
2997// // Return error on first attempt
2998// let stream = futures::stream::once(async move {
2999// Err(LanguageModelCompletionError::Overloaded)
3000// });
3001// async move { Ok(stream.boxed()) }.boxed()
3002// } else {
3003// // Succeed on retry
3004// self.inner.stream_completion(request, cx)
3005// }
3006// }
3007// }
3008
3009// let fail_once_model = Arc::new(FailOnceModel {
3010// inner: Arc::new(FakeLanguageModel::default()),
3011// failed_once: Arc::new(Mutex::new(false)),
3012// });
3013
3014// // Insert a user message
3015// thread.update(cx, |thread, cx| {
3016// thread.insert_user_message(
3017// "Test message",
3018// ContextLoadResult::default(),
3019// None,
3020// vec![],
3021// cx,
3022// );
3023// });
3024
3025// // Start completion with fail-once model
3026// thread.update(cx, |thread, cx| {
3027// thread.send_to_model(
3028// fail_once_model.clone(),
3029// CompletionIntent::UserPrompt,
3030// None,
3031// cx,
3032// );
3033// });
3034
3035// cx.run_until_parked();
3036
3037// // Verify retry state exists after first failure
3038// thread.read_with(cx, |thread, _| {
3039// assert!(
3040// thread.retry_state.is_some(),
3041// "Should have retry state after failure"
3042// );
3043// });
3044
3045// // Wait for retry delay
3046// cx.executor()
3047// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
3048// cx.run_until_parked();
3049
3050// // The retry should now use our FailOnceModel which should succeed
3051// // We need to help the FakeLanguageModel complete the stream
3052// let inner_fake = fail_once_model.inner.clone();
3053
3054// // Wait a bit for the retry to start
3055// cx.run_until_parked();
3056
3057// // Check for pending completions and complete them
3058// if let Some(pending) = inner_fake.pending_completions().first() {
3059// inner_fake.stream_completion_response(pending, "Success!");
3060// inner_fake.end_completion_stream(pending);
3061// }
3062// cx.run_until_parked();
3063
3064// thread.read_with(cx, |thread, _| {
3065// assert!(
3066// thread.retry_state.is_none(),
3067// "Retry state should be cleared after successful completion"
3068// );
3069
3070// let has_assistant_message = thread
3071// .messages
3072// .iter()
3073// .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
3074// assert!(
3075// has_assistant_message,
3076// "Should have an assistant message after successful retry"
3077// );
3078// });
3079// }
3080
3081// #[gpui::test]
3082// async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
3083// init_test_settings(cx);
3084
3085// let project = create_test_project(cx, json!({})).await;
3086// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
3087
3088// // Create a model that returns rate limit error with retry_after
3089// struct RateLimitModel {
3090// inner: Arc<FakeLanguageModel>,
3091// }
3092
3093// impl LanguageModel for RateLimitModel {
3094// fn id(&self) -> LanguageModelId {
3095// self.inner.id()
3096// }
3097
3098// fn name(&self) -> LanguageModelName {
3099// self.inner.name()
3100// }
3101
3102// fn provider_id(&self) -> LanguageModelProviderId {
3103// self.inner.provider_id()
3104// }
3105
3106// fn provider_name(&self) -> LanguageModelProviderName {
3107// self.inner.provider_name()
3108// }
3109
3110// fn supports_tools(&self) -> bool {
3111// self.inner.supports_tools()
3112// }
3113
3114// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
3115// self.inner.supports_tool_choice(choice)
3116// }
3117
3118// fn supports_images(&self) -> bool {
3119// self.inner.supports_images()
3120// }
3121
3122// fn telemetry_id(&self) -> String {
3123// self.inner.telemetry_id()
3124// }
3125
3126// fn max_token_count(&self) -> u64 {
3127// self.inner.max_token_count()
3128// }
3129
3130// fn count_tokens(
3131// &self,
3132// request: LanguageModelRequest,
3133// cx: &App,
3134// ) -> BoxFuture<'static, Result<u64>> {
3135// self.inner.count_tokens(request, cx)
3136// }
3137
3138// fn stream_completion(
3139// &self,
3140// _request: LanguageModelRequest,
3141// _cx: &AsyncApp,
3142// ) -> BoxFuture<
3143// 'static,
3144// Result<
3145// BoxStream<
3146// 'static,
3147// Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
3148// >,
3149// LanguageModelCompletionError,
3150// >,
3151// > {
3152// async move {
3153// let stream = futures::stream::once(async move {
3154// Err(LanguageModelCompletionError::RateLimitExceeded {
3155// retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
3156// })
3157// });
3158// Ok(stream.boxed())
3159// }
3160// .boxed()
3161// }
3162
3163// fn as_fake(&self) -> &FakeLanguageModel {
3164// &self.inner
3165// }
3166// }
3167
3168// let model = Arc::new(RateLimitModel {
3169// inner: Arc::new(FakeLanguageModel::default()),
3170// });
3171
3172// // Insert a user message
3173// thread.update(cx, |thread, cx| {
3174// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
3175// });
3176
3177// // Start completion
3178// thread.update(cx, |thread, cx| {
3179// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3180// });
3181
3182// cx.run_until_parked();
3183
3184// let retry_count = thread.update(cx, |thread, _| {
3185// thread
3186// .messages
3187// .iter()
3188// .filter(|m| {
3189// m.ui_only
3190// && m.segments.iter().any(|s| {
3191// if let MessageSegment::Text(text) = s {
3192// text.contains("rate limit exceeded")
3193// } else {
3194// false
3195// }
3196// })
3197// })
3198// .count()
3199// });
3200// assert_eq!(retry_count, 1, "Should have scheduled one retry");
3201
3202// thread.read_with(cx, |thread, _| {
3203// assert!(
3204// thread.retry_state.is_none(),
3205// "Rate limit errors should not set retry_state"
3206// );
3207// });
3208
3209// // Verify we have one retry message
3210// thread.read_with(cx, |thread, _| {
3211// let retry_messages = thread
3212// .messages
3213// .iter()
3214// .filter(|msg| {
3215// msg.ui_only
3216// && msg.segments.iter().any(|seg| {
3217// if let MessageSegment::Text(text) = seg {
3218// text.contains("rate limit exceeded")
3219// } else {
3220// false
3221// }
3222// })
3223// })
3224// .count();
3225// assert_eq!(
3226// retry_messages, 1,
3227// "Should have one rate limit retry message"
3228// );
3229// });
3230
3231// // Check that retry message doesn't include attempt count
3232// thread.read_with(cx, |thread, _| {
3233// let retry_message = thread
3234// .messages
3235// .iter()
3236// .find(|msg| msg.role == Role::System && msg.ui_only)
3237// .expect("Should have a retry message");
3238
3239// // Check that the message doesn't contain attempt count
3240// if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
3241// assert!(
3242// !text.contains("attempt"),
3243// "Rate limit retry message should not contain attempt count"
3244// );
3245// assert!(
3246// text.contains(&format!(
3247// "Retrying in {} seconds",
3248// TEST_RATE_LIMIT_RETRY_SECS
3249// )),
3250// "Rate limit retry message should contain retry delay"
3251// );
3252// }
3253// });
3254// }
3255
3256// #[gpui::test]
3257// async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
3258// init_test_settings(cx);
3259
3260// let project = create_test_project(cx, json!({})).await;
3261// let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
3262
3263// // Insert a regular user message
3264// thread.update(cx, |thread, cx| {
3265// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
3266// });
3267
3268// // Insert a UI-only message (like our retry notifications)
3269// thread.update(cx, |thread, cx| {
3270// let id = thread.next_message_id.post_inc();
3271// thread.messages.push(Message {
3272// id,
3273// role: Role::System,
3274// segments: vec![MessageSegment::Text(
3275// "This is a UI-only message that should not be sent to the model".to_string(),
3276// )],
3277// loaded_context: LoadedContext::default(),
3278// creases: Vec::new(),
3279// is_hidden: true,
3280// ui_only: true,
3281// });
3282// cx.emit(ThreadEvent::MessageAdded(id));
3283// });
3284
3285// // Insert another regular message
3286// thread.update(cx, |thread, cx| {
3287// thread.insert_user_message(
3288// "How are you?",
3289// ContextLoadResult::default(),
3290// None,
3291// vec![],
3292// cx,
3293// );
3294// });
3295
3296// // Generate the completion request
3297// let request = thread.update(cx, |thread, cx| {
3298// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3299// });
3300
3301// // Verify that the request only contains non-UI-only messages
// // Should have the system prompt + 2 user messages, but not the UI-only message
3303// let user_messages: Vec<_> = request
3304// .messages
3305// .iter()
3306// .filter(|msg| msg.role == Role::User)
3307// .collect();
3308// assert_eq!(
3309// user_messages.len(),
3310// 2,
3311// "Should have exactly 2 user messages"
3312// );
3313
3314// // Verify the UI-only content is not present anywhere in the request
3315// let request_text = request
3316// .messages
3317// .iter()
3318// .flat_map(|msg| &msg.content)
3319// .filter_map(|content| match content {
3320// MessageContent::Text(text) => Some(text.as_str()),
3321// _ => None,
3322// })
3323// .collect::<String>();
3324
3325// assert!(
3326// !request_text.contains("UI-only message"),
3327// "UI-only message content should not be in the request"
3328// );
3329
3330// // Verify the thread still has all 3 messages (including UI-only)
3331// thread.read_with(cx, |thread, _| {
3332// assert_eq!(
3333// thread.messages().count(),
3334// 3,
3335// "Thread should have 3 messages"
3336// );
3337// assert_eq!(
3338// thread.messages().filter(|m| m.ui_only).count(),
3339// 1,
3340// "Thread should have 1 UI-only message"
3341// );
3342// });
3343
3344// // Verify that UI-only messages are not serialized
3345// let serialized = thread
3346// .update(cx, |thread, cx| thread.serialize(cx))
3347// .await
3348// .unwrap();
3349// assert_eq!(
3350// serialized.messages.len(),
3351// 2,
3352// "Serialized thread should only have 2 messages (no UI-only)"
3353// );
3354// }
3355
3356// #[gpui::test]
3357// async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
3358// init_test_settings(cx);
3359
3360// let project = create_test_project(cx, json!({})).await;
3361// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
3362
// // Create a model that returns an overloaded error
3364// let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
3365
3366// // Insert a user message
3367// thread.update(cx, |thread, cx| {
3368// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
3369// });
3370
3371// // Start completion
3372// thread.update(cx, |thread, cx| {
3373// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3374// });
3375
3376// cx.run_until_parked();
3377
3378// // Verify retry was scheduled by checking for retry message
3379// let has_retry_message = thread.read_with(cx, |thread, _| {
3380// thread.messages.iter().any(|m| {
3381// m.ui_only
3382// && m.segments.iter().any(|s| {
3383// if let MessageSegment::Text(text) = s {
3384// text.contains("Retrying") && text.contains("seconds")
3385// } else {
3386// false
3387// }
3388// })
3389// })
3390// });
3391// assert!(has_retry_message, "Should have scheduled a retry");
3392
3393// // Cancel the completion before the retry happens
3394// thread.update(cx, |thread, cx| {
3395// thread.cancel_last_completion(None, cx);
3396// });
3397
3398// cx.run_until_parked();
3399
3400// // The retry should not have happened - no pending completions
3401// let fake_model = model.as_fake();
3402// assert_eq!(
3403// fake_model.pending_completions().len(),
3404// 0,
3405// "Should have no pending completions after cancellation"
3406// );
3407
3408// // Verify the retry was cancelled by checking retry state
3409// thread.read_with(cx, |thread, _| {
3410// if let Some(retry_state) = &thread.retry_state {
3411// panic!(
3412// "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
3413// retry_state.attempt, retry_state.max_attempts, retry_state.intent
3414// );
3415// }
3416// });
3417// }
3418
3419// fn test_summarize_error(
3420// model: &Arc<dyn LanguageModel>,
3421// thread: &Entity<Thread>,
3422// cx: &mut TestAppContext,
3423// ) {
3424// thread.update(cx, |thread, cx| {
3425// thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3426// thread.send_to_model(
3427// model.clone(),
3428// CompletionIntent::ThreadSummarization,
3429// None,
3430// cx,
3431// );
3432// });
3433
3434// let fake_model = model.as_fake();
3435// simulate_successful_response(&fake_model, cx);
3436
3437// thread.read_with(cx, |thread, _| {
3438// assert!(matches!(thread.summary(), ThreadSummary::Generating));
3439// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3440// });
3441
3442// // Simulate summary request ending
3443// cx.run_until_parked();
3444// fake_model.end_last_completion_stream();
3445// cx.run_until_parked();
3446
3447// // State is set to Error and default message
3448// thread.read_with(cx, |thread, _| {
3449// assert!(matches!(thread.summary(), ThreadSummary::Error));
3450// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3451// });
3452// }
3453
3454// fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
3455// cx.run_until_parked();
3456// fake_model.stream_last_completion_response("Assistant response");
3457// fake_model.end_last_completion_stream();
3458// cx.run_until_parked();
3459// }
3460
3461// fn init_test_settings(cx: &mut TestAppContext) {
3462// cx.update(|cx| {
3463// let settings_store = SettingsStore::test(cx);
3464// cx.set_global(settings_store);
3465// language::init(cx);
3466// Project::init_settings(cx);
3467// AgentSettings::register(cx);
3468// prompt_store::init(cx);
3469// thread_store::init(cx);
3470// workspace::init_settings(cx);
3471// language_model::init_settings(cx);
3472// ThemeSettings::register(cx);
3473// ToolRegistry::default_global(cx);
3474// });
3475// }
3476
3477// // Helper to create a test project with test files
3478// async fn create_test_project(
3479// cx: &mut TestAppContext,
3480// files: serde_json::Value,
3481// ) -> Entity<Project> {
3482// let fs = FakeFs::new(cx.executor());
3483// fs.insert_tree(path!("/test"), files).await;
3484// Project::test(fs, [path!("/test").as_ref()], cx).await
3485// }
3486
3487// async fn setup_test_environment(
3488// cx: &mut TestAppContext,
3489// project: Entity<Project>,
3490// ) -> (
3491// Entity<Workspace>,
3492// Entity<ThreadStore>,
3493// Entity<Thread>,
3494// Entity<ContextStore>,
3495// Arc<dyn LanguageModel>,
3496// ) {
3497// let (workspace, cx) =
3498// cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3499
3500// let thread_store = cx
3501// .update(|_, cx| {
3502// ThreadStore::load(
3503// project.clone(),
3504// cx.new(|_| ToolWorkingSet::default()),
3505// None,
3506// Arc::new(PromptBuilder::new(None).unwrap()),
3507// cx,
3508// )
3509// })
3510// .await
3511// .unwrap();
3512
3513// let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3514// let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3515
3516// let provider = Arc::new(FakeLanguageModelProvider);
3517// let model = provider.test_model();
3518// let model: Arc<dyn LanguageModel> = Arc::new(model);
3519
3520// cx.update(|_, cx| {
3521// LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3522// registry.set_default_model(
3523// Some(ConfiguredModel {
3524// provider: provider.clone(),
3525// model: model.clone(),
3526// }),
3527// cx,
3528// );
3529// registry.set_thread_summary_model(
3530// Some(ConfiguredModel {
3531// provider,
3532// model: model.clone(),
3533// }),
3534// cx,
3535// );
3536// })
3537// });
3538
3539// (workspace, thread_store, thread, context_store, model)
3540// }
3541
3542// async fn add_file_to_context(
3543// project: &Entity<Project>,
3544// context_store: &Entity<ContextStore>,
3545// path: &str,
3546// cx: &mut TestAppContext,
3547// ) -> Result<Entity<language::Buffer>> {
3548// let buffer_path = project
3549// .read_with(cx, |project, cx| project.find_project_path(path, cx))
3550// .unwrap();
3551
3552// let buffer = project
3553// .update(cx, |project, cx| {
3554// project.open_buffer(buffer_path.clone(), cx)
3555// })
3556// .await
3557// .unwrap();
3558
3559// context_store.update(cx, |context_store, cx| {
3560// context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3561// });
3562
3563// Ok(buffer)
3564// }
3565// }