1use crate::{
2 AgentThread, AgentThreadUserMessageChunk, MessageId, ThreadId,
3 agent_profile::AgentProfile,
4 context::{AgentContextHandle, LoadedContext},
5 thread_store::{SharedProjectContext, ThreadStore},
6};
7use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
8use anyhow::Result;
9use assistant_tool::{ActionLog, AnyToolCard, Tool};
10use chrono::{DateTime, Utc};
11use client::{ModelRequestUsage, RequestUsage};
12use collections::{HashMap, HashSet};
13use futures::{FutureExt, channel::oneshot, future::Shared};
14use git::repository::DiffType;
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
17 Window,
18};
19use language_model::{
20 ConfiguredModel, LanguageModelId, LanguageModelToolUseId, Role, StopReason, TokenUsage,
21};
22use markdown::Markdown;
23use postage::stream::Stream as _;
24use project::{
25 Project,
26 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
27};
28use proto::Plan;
29use serde::{Deserialize, Serialize};
30use settings::Settings;
31use std::{ops::Range, sync::Arc, time::Instant};
32use thiserror::Error;
33use util::ResultExt as _;
34use zed_llm_client::UsageLimit;
35
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Range the crease occupies — presumably offsets into the message text; TODO confirm.
    pub range: Range<usize>,
    // Icon displayed on the collapsed crease.
    pub icon_path: SharedString,
    // Label displayed on the collapsed crease.
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
45
/// State of a single tool invocation as displayed within an assistant message.
pub enum MessageToolCall {
    /// Tool call has been received but is not yet running.
    Pending {
        tool: Arc<dyn Tool>,
        input: serde_json::Value,
    },
    /// Waiting for the user to approve the call; the decision is reported
    /// through `confirm_tx`.
    NeedsConfirmation {
        tool: Arc<dyn Tool>,
        input_json: serde_json::Value,
        confirm_tx: oneshot::Sender<bool>,
    },
    /// The call was approved; rendered via the tool's card.
    Confirmed {
        card: AnyToolCard,
    },
    /// The user declined to run the tool.
    Declined {
        tool: Arc<dyn Tool>,
        input_json: serde_json::Value,
    },
}
64
/// A single message in the thread, either authored by the user or produced by
/// the assistant.
pub enum Message {
    User {
        id: MessageId,
        // Raw text the user typed.
        text: String,
        // Creases used to restore collapsed context chips when re-editing.
        creases: Vec<MessageCrease>,
        // Context that was loaded and attached to this message.
        loaded_context: LoadedContext,
    },
    Assistant {
        id: MessageId,
        // Ordered text/thinking/tool-call chunks making up the response.
        chunks: Vec<AssistantMessageChunk>,
    },
}
77
78impl Message {
79 pub fn id(&self) -> MessageId {
80 match self {
81 Message::User { id, .. } => *id,
82 Message::Assistant { id, .. } => *id,
83 }
84 }
85}
86
/// One renderable piece of an assistant message.
pub enum AssistantMessageChunk {
    /// Ordinary response text, rendered as markdown.
    Text(Entity<Markdown>),
    /// Model "thinking" content, rendered as markdown.
    Thinking(Entity<Markdown>),
    /// A tool invocation and its current state.
    ToolCall(MessageToolCall),
}
92
/// Snapshot of the project taken alongside a thread, used for telemetry and
/// feedback reports.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    // One entry per visible worktree.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers that had unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    // When the snapshot was captured (UTC).
    pub timestamp: DateTime<Utc>,
}
99
/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    // Absolute path of the worktree root (empty if it could not be read).
    pub worktree_path: String,
    // `None` when the worktree is not inside a local git repository.
    pub git_state: Option<GitState>,
}
105
/// Captured git repository state for a worktree snapshot. All fields are
/// best-effort and may be `None` when the information is unavailable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    // URL of the "origin" remote, if configured.
    pub remote_url: Option<String>,
    // SHA of HEAD at snapshot time.
    pub head_sha: Option<String>,
    // Name of the checked-out branch.
    pub current_branch: Option<String>,
    // Diff of HEAD against the worktree contents.
    pub diff: Option<String>,
}
113
/// Associates a git checkpoint with the message it was taken before, so the
/// project can be restored to that point.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
119
/// User rating submitted for a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
125
/// Outcome of the most recent checkpoint-restore attempt, used to drive UI
/// state while a restore is in flight or after it failed.
pub enum LastRestoreCheckpoint {
    /// Restore is still running.
    Pending {
        message_id: MessageId,
    },
    /// Restore failed; `error` holds the stringified failure.
    Error {
        message_id: MessageId,
        error: String,
    },
}
135
136impl LastRestoreCheckpoint {
137 pub fn message_id(&self) -> MessageId {
138 match self {
139 LastRestoreCheckpoint::Pending { message_id } => *message_id,
140 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
141 }
142 }
143}
144
/// Progress of detailed-summary generation for the thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    /// No summary has been produced yet.
    #[default]
    NotGenerated,
    /// A summary covering messages up to `message_id` is being generated.
    Generating {
        message_id: MessageId,
    },
    /// A summary covering messages up to `message_id` is available.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
157
158impl DetailedSummaryState {
159 fn text(&self) -> Option<SharedString> {
160 if let Self::Generated { text, .. } = self {
161 Some(text.clone())
162 } else {
163 None
164 }
165 }
166}
167
/// Token consumption relative to the active model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    // Tokens consumed so far.
    pub total: u64,
    // Context-window size; 0 when no model is selected.
    pub max: u64,
}
173
174impl TotalTokenUsage {
175 pub fn ratio(&self) -> TokenUsageRatio {
176 #[cfg(debug_assertions)]
177 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
178 .unwrap_or("0.8".to_string())
179 .parse()
180 .unwrap();
181 #[cfg(not(debug_assertions))]
182 let warning_threshold: f32 = 0.8;
183
184 // When the maximum is unknown because there is no selected model,
185 // avoid showing the token limit warning.
186 if self.max == 0 {
187 TokenUsageRatio::Normal
188 } else if self.total >= self.max {
189 TokenUsageRatio::Exceeded
190 } else if self.total as f32 / self.max as f32 >= warning_threshold {
191 TokenUsageRatio::Warning
192 } else {
193 TokenUsageRatio::Normal
194 }
195 }
196
197 pub fn add(&self, tokens: u64) -> TotalTokenUsage {
198 TotalTokenUsage {
199 total: self.total + tokens,
200 max: self.max,
201 }
202 }
203}
204
/// Severity bucket for token usage, derived by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold (or the max is unknown).
    #[default]
    Normal,
    /// Usage crossed the warning threshold but is still under the limit.
    Warning,
    /// Usage reached or exceeded the context window.
    Exceeded,
}
212
/// Position of a request in the completion queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request is being sent to the provider.
    Sending,
    /// Request is waiting in line at the given position.
    Queued { position: usize },
    /// Request has begun streaming.
    Started,
}
219
/// A thread of conversation with the LLM.
pub struct Thread {
    // Backend that owns the canonical conversation state.
    agent_thread: Arc<dyn AgentThread>,
    // Title shown in the UI; starts out `Pending`.
    title: ThreadTitle,
    // In-flight send task; `Some` while a response is being generated.
    pending_send: Option<Task<Result<()>>>,
    // Background task generating the short summary/title.
    pending_summary: Task<Option<()>>,
    // Background task generating the detailed summary.
    detailed_summary_task: Task<Option<()>>,
    // Watch channel broadcasting detailed-summary progress.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    // Normal vs. burn mode; initialized from settings.
    completion_mode: agent_settings::CompletionMode,
    // Materialized messages for display.
    messages: Vec<Message>,
    // Git checkpoints taken before each message, keyed by message id.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    project: Entity<Project>,
    // Tracks buffer reads/edits performed by the agent.
    action_log: Entity<ActionLog>,
    // Outcome of the most recent restore attempt, for UI display.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint awaiting finalization once generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Project snapshot captured at thread creation; shared for telemetry.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    // Running total of tokens across all requests.
    cumulative_token_usage: TokenUsage,
    // Set when the last message exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    // todo!(keep track of retries from the underlying agent)
    // Thread-level rating submitted by the user.
    feedback: Option<ThreadFeedback>,
    // Per-message ratings submitted by the user.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    // Throttle for auto-capture telemetry.
    last_auto_capture_at: Option<Instant>,
    // Time of the last streamed chunk; used for staleness detection.
    last_received_chunk_at: Option<Instant>,
}
247
/// Lifecycle of the thread's auto-generated title.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadTitle {
    /// Generation has not started yet.
    Pending,
    /// A title is being generated.
    Generating,
    /// A title was generated successfully.
    Ready(SharedString),
    /// Title generation failed.
    Error,
}
255
256impl ThreadTitle {
257 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
258
259 pub fn or_default(&self) -> SharedString {
260 self.unwrap_or(Self::DEFAULT)
261 }
262
263 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
264 self.ready().unwrap_or_else(|| message.into())
265 }
266
267 pub fn ready(&self) -> Option<SharedString> {
268 match self {
269 ThreadTitle::Ready(summary) => Some(summary.clone()),
270 ThreadTitle::Pending | ThreadTitle::Generating | ThreadTitle::Error => None,
271 }
272 }
273}
274
/// Details recorded when a message exceeds the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
282
283impl Thread {
    /// Builds a UI-side `Thread` wrapping an existing `AgentThread`.
    ///
    /// The initial project snapshot is captured eagerly so it reflects the
    /// project state at thread creation time.
    ///
    /// NOTE(review): `messages` is still `todo!("read from agent")`, so this
    /// constructor panics when reached until that field is populated.
    pub fn load(
        agent_thread: Arc<dyn AgentThread>,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        Self {
            agent_thread,
            title: ThreadTitle::Pending,
            pending_send: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: todo!("read from agent"),
            checkpoints_by_message: HashMap::default(),
            project: project.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
        }
    }
321
    /// The id of the underlying agent thread.
    pub fn id(&self) -> ThreadId {
        self.agent_thread.id()
    }
325
    /// The agent profile in effect for this thread. Not yet implemented.
    pub fn profile(&self) -> &AgentProfile {
        todo!()
    }
329
    /// Switches the thread to the profile with the given id. Not yet
    /// implemented; the commented code preserves the pre-refactor behavior.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        todo!()
        // if &id != self.profile.id() {
        //     self.profile = AgentProfile::new(id, self.tools.clone());
        //     cx.emit(ThreadEvent::ProfileChanged);
        // }
    }
337
    /// Whether the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
341
    /// Shared project context attached to this thread. Not yet implemented.
    pub fn project_context(&self) -> SharedProjectContext {
        todo!()
        // self.project_context.clone()
    }
346
    /// The current title state of the thread.
    pub fn title(&self) -> &ThreadTitle {
        &self.title
    }
350
    /// Sets a user-provided title. Not yet implemented; the commented code is
    /// the pre-refactor behavior (note it still uses the old "summary" naming).
    pub fn set_title(&mut self, new_title: impl Into<SharedString>, cx: &mut Context<Self>) {
        todo!()
        // let current_summary = match &self.summary {
        //     ThreadSummary::Pending | ThreadSummary::Generating => return,
        //     ThreadSummary::Ready(summary) => summary,
        //     ThreadSummary::Error => &ThreadSummary::DEFAULT,
        // };

        // let mut new_summary = new_summary.into();

        // if new_summary.is_empty() {
        //     new_summary = ThreadSummary::DEFAULT;
        // }

        // if current_summary != &new_summary {
        //     self.summary = ThreadSummary::Ready(new_summary);
        //     cx.emit(ThreadEvent::SummaryChanged);
        // }
    }
370
    /// Requests a fresh summary for the thread. Not yet implemented.
    pub fn regenerate_summary(&self, cx: &mut Context<Self>) {
        todo!()
        // self.summarize(cx);
    }
375
    /// The completion mode currently in effect for this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
379
    /// Overrides the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
383
    /// All messages in the thread, in order.
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }
387
    /// Whether a response is currently being generated (a send is in flight).
    pub fn is_generating(&self) -> bool {
        self.pending_send.is_some()
    }
391
392 /// Indicates whether streaming of language model events is stale.
393 /// When `is_generating()` is false, this method returns `None`.
394 pub fn is_generation_stale(&self) -> Option<bool> {
395 const STALE_THRESHOLD: u128 = 250;
396
397 self.last_received_chunk_at
398 .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
399 }
400
    /// Records that a streamed chunk just arrived, resetting staleness timing.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
404
    /// The checkpoint taken before the given message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
408
    /// Restores the project to the given checkpoint.
    ///
    /// Marks the restore as pending (for UI), asks the git store to restore
    /// the checkpoint, and on success truncates the thread back to the
    /// checkpoint's message. Failures are surfaced via
    /// `last_restore_checkpoint` and the returned task's result.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so the UI can display it.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop messages at and after the checkpointed message.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any in-flight pending checkpoint is obsolete after a restore.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
444
445 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
446 let pending_checkpoint = if self.is_generating() {
447 return;
448 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
449 checkpoint
450 } else {
451 return;
452 };
453
454 self.finalize_checkpoint(pending_checkpoint, cx);
455 }
456
    /// Decides whether a pending checkpoint is worth keeping.
    ///
    /// Takes a fresh git checkpoint and compares it against the pending one:
    /// if the project changed since the pending checkpoint was taken, the
    /// checkpoint is inserted so the user can restore to it. If comparison
    /// fails (or checkpointing fails), the checkpoint is kept conservatively.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // Treat a failed comparison as "changed" so we keep the
                    // checkpoint rather than silently dropping it.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Could not take a final checkpoint; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
491
    /// Records a checkpoint for its message and notifies the UI.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
498
    /// State of the most recent checkpoint-restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
502
    /// Removes the message with the given id and everything after it, along
    /// with their checkpoints.
    ///
    /// NOTE(review): the leading `todo!` makes the rest of this body
    /// unreachable until truncation is forwarded to the agent backend.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        todo!("call truncate on the agent");
        let Some(message_ix) = self
            .messages
            .iter()
            // Search from the end: message ids are expected near the tail.
            .rposition(|message| message.id() == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id());
        }
        cx.notify();
    }
517
    /// Whether the message at index `ix` ends an agent turn. Not yet
    /// implemented; the commented code is the pre-refactor behavior.
    pub fn is_turn_end(&self, ix: usize) -> bool {
        todo!()
        // if self.messages.is_empty() {
        //     return false;
        // }

        // if !self.is_generating() && ix == self.messages.len() - 1 {
        //     return true;
        // }

        // let Some(message) = self.messages.get(ix) else {
        //     return false;
        // };

        // if message.role != Role::Assistant {
        //     return false;
        // }

        // self.messages
        //     .get(ix + 1)
        //     .and_then(|message| {
        //         self.message(message.id)
        //             .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
        //     })
        //     .unwrap_or(false)
    }
544
    /// Whether the provider reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
548
    /// Returns whether any pending tool uses may perform edits
    ///
    /// Not yet implemented.
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        todo!()
    }
553
554 // pub fn insert_user_message(
555 // &mut self,
556 // text: impl Into<String>,
557 // loaded_context: ContextLoadResult,
558 // git_checkpoint: Option<GitStoreCheckpoint>,
559 // creases: Vec<MessageCrease>,
560 // cx: &mut Context<Self>,
561 // ) -> AgentThreadMessageId {
562 // todo!("move this logic into send")
563 // if !loaded_context.referenced_buffers.is_empty() {
564 // self.action_log.update(cx, |log, cx| {
565 // for buffer in loaded_context.referenced_buffers {
566 // log.buffer_read(buffer, cx);
567 // }
568 // });
569 // }
570
571 // let message_id = self.insert_message(
572 // Role::User,
573 // vec![MessageSegment::Text(text.into())],
574 // loaded_context.loaded_context,
575 // creases,
576 // false,
577 // cx,
578 // );
579
580 // if let Some(git_checkpoint) = git_checkpoint {
581 // self.pending_checkpoint = Some(ThreadCheckpoint {
582 // message_id,
583 // git_checkpoint,
584 // });
585 // }
586
587 // self.auto_capture_telemetry(cx);
588
589 // message_id
590 // }
591
    /// Sets the configured language model for this thread. Not yet implemented.
    pub fn set_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        todo!()
    }
595
    /// The configured language model for this thread. Not yet implemented.
    pub fn model(&self) -> Option<ConfiguredModel> {
        todo!()
    }
599
    /// Sends a user message to the agent and begins generation. Not yet
    /// implemented.
    pub fn send(
        &mut self,
        message: Vec<AgentThreadUserMessageChunk>,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        todo!()
    }
608
    /// Resumes generation after an interruption (e.g. a tool-use limit).
    /// Not yet implemented.
    pub fn resume(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        todo!()
    }
612
    /// Replaces an existing user message and regenerates from that point.
    /// Not yet implemented.
    pub fn edit(
        &mut self,
        message_id: MessageId,
        message: Vec<AgentThreadUserMessageChunk>,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        todo!()
    }
622
    /// Cancels any in-flight generation; the return value presumably reports
    /// whether something was cancelled — TODO confirm. Not yet implemented.
    pub fn cancel(&mut self, window: &mut Window, cx: &mut Context<Self>) -> bool {
        todo!()
    }
626
627 // pub fn insert_invisible_continue_message(
628 // &mut self,
629 // cx: &mut Context<Self>,
630 // ) -> AgentThreadMessageId {
631 // let id = self.insert_message(
632 // Role::User,
633 // vec![MessageSegment::Text("Continue where you left off".into())],
634 // LoadedContext::default(),
635 // vec![],
636 // true,
637 // cx,
638 // );
639 // self.pending_checkpoint = None;
640
641 // id
642 // }
643
644 // pub fn insert_assistant_message(
645 // &mut self,
646 // segments: Vec<MessageSegment>,
647 // cx: &mut Context<Self>,
648 // ) -> AgentThreadMessageId {
649 // self.insert_message(
650 // Role::Assistant,
651 // segments,
652 // LoadedContext::default(),
653 // Vec::new(),
654 // false,
655 // cx,
656 // )
657 // }
658
659 // pub fn insert_message(
660 // &mut self,
661 // role: Role,
662 // segments: Vec<MessageSegment>,
663 // loaded_context: LoadedContext,
664 // creases: Vec<MessageCrease>,
665 // is_hidden: bool,
666 // cx: &mut Context<Self>,
667 // ) -> AgentThreadMessageId {
668 // let id = self.next_message_id.post_inc();
669 // self.messages.push(Message {
670 // id,
671 // role,
672 // segments,
673 // loaded_context,
674 // creases,
675 // is_hidden,
676 // ui_only: false,
677 // });
678 // self.touch_updated_at();
679 // cx.emit(ThreadEvent::MessageAdded(id));
680 // id
681 // }
682
683 // pub fn edit_message(
684 // &mut self,
685 // id: AgentThreadMessageId,
686 // new_role: Role,
687 // new_segments: Vec<MessageSegment>,
688 // creases: Vec<MessageCrease>,
689 // loaded_context: Option<LoadedContext>,
690 // checkpoint: Option<GitStoreCheckpoint>,
691 // cx: &mut Context<Self>,
692 // ) -> bool {
693 // let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
694 // return false;
695 // };
696 // message.role = new_role;
697 // message.segments = new_segments;
698 // message.creases = creases;
699 // if let Some(context) = loaded_context {
700 // message.loaded_context = context;
701 // }
702 // if let Some(git_checkpoint) = checkpoint {
703 // self.checkpoints_by_message.insert(
704 // id,
705 // ThreadCheckpoint {
706 // message_id: id,
707 // git_checkpoint,
708 // },
709 // );
710 // }
711 // self.touch_updated_at();
712 // cx.emit(ThreadEvent::MessageEdited(id));
713 // true
714 // }
715
    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
    ///
    /// NOTE(review): both match arms are still `todo!`, so this currently
    /// panics on the first message; the commented code is the pre-refactor
    /// rendering logic.
    pub fn text(&self, cx: &App) -> String {
        let mut text = String::new();

        for message in &self.messages {
            match message {
                Message::User {
                    id,
                    text,
                    creases,
                    loaded_context,
                } => todo!(),
                Message::Assistant {
                    id,
                    chunks: segments,
                } => todo!(),
            }
            // text.push_str(match message.role {
            //     language_model::Role::User => "User:",
            //     language_model::Role::Assistant => "Agent:",
            //     language_model::Role::System => "System:",
            // });
            // text.push('\n');

            // text.push_str("<think>");
            // text.push_str(message.thinking.read(cx).source());
            // text.push_str("</think>");
            // text.push_str(message.text.read(cx).source());

            text.push('\n');
        }

        text
    }
752
    /// Whether any tool was used after the most recent user message. Not yet
    /// implemented; the commented code is the pre-refactor behavior.
    pub fn used_tools_since_last_user_message(&self) -> bool {
        todo!()
        // for message in self.messages.iter().rev() {
        //     if self.tool_use.message_has_tool_results(message.id) {
        //         return true;
        //     } else if message.role == Role::User {
        //         return false;
        //     }
        // }

        // false
    }
765
    /// Kicks off detailed-summary generation unless one already exists (or is
    /// being generated) for the current last message.
    ///
    /// Progress is broadcast through the `detailed_summary_tx` watch channel;
    /// on failure the state reverts to `NotGenerated`.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id()) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let summary = self.agent_thread.summary();

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let Some(summary) = summary.await.log_err() else {
                // Generation failed; reset so a later call can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            Some(())
        });
    }
819
    /// Waits until a detailed summary is available and returns it, falling
    /// back to the plain-text rendering when generation is not in progress.
    /// Returns `None` if the thread entity or the watch channel is dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating; keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, cx| this.text(cx).into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
837
838 pub fn latest_detailed_summary_or_text(&self, cx: &App) -> SharedString {
839 self.detailed_summary_rx
840 .borrow()
841 .text()
842 .unwrap_or_else(|| self.text(cx).into())
843 }
844
    /// Whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
851
    /// The thread-level feedback submitted by the user, if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
855
    /// The feedback submitted for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
859
    /// Records feedback for a message and reports it via telemetry. Not yet
    /// implemented; the commented code is the pre-refactor behavior.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        todo!()
        // if self.message_feedback.get(&message_id) == Some(&feedback) {
        //     return Task::ready(Ok(()));
        // }

        // let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        // let serialized_thread = self.serialize(cx);
        // let thread_id = self.id().clone();
        // let client = self.project.read(cx).client();

        // let enabled_tool_names: Vec<String> = self
        //     .profile
        //     .enabled_tools(cx)
        //     .iter()
        //     .map(|tool| tool.name())
        //     .collect();

        // self.message_feedback.insert(message_id, feedback);

        // cx.notify();

        // let message_content = self
        //     .message(message_id)
        //     .map(|msg| msg.to_string())
        //     .unwrap_or_default();

        // cx.background_spawn(async move {
        //     let final_project_snapshot = final_project_snapshot.await;
        //     let serialized_thread = serialized_thread.await?;
        //     let thread_data =
        //         serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

        //     let rating = match feedback {
        //         ThreadFeedback::Positive => "positive",
        //         ThreadFeedback::Negative => "negative",
        //     };
        //     telemetry::event!(
        //         "Assistant Thread Rated",
        //         rating,
        //         thread_id,
        //         enabled_tool_names,
        //         message_id = message_id,
        //         message_content,
        //         thread_data,
        //         final_project_snapshot
        //     );
        //     client.telemetry().flush_events().await;

        //     Ok(())
        // })
    }
917
    /// Records thread-level feedback and reports it via telemetry. Not yet
    /// implemented; the commented code is the pre-refactor behavior.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        todo!()
        // let last_assistant_message_id = self
        //     .messages
        //     .iter()
        //     .rev()
        //     .find(|msg| msg.role == Role::Assistant)
        //     .map(|msg| msg.id);

        // if let Some(message_id) = last_assistant_message_id {
        //     self.report_message_feedback(message_id, feedback, cx)
        // } else {
        //     let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        //     let serialized_thread = self.serialize(cx);
        //     let thread_id = self.id().clone();
        //     let client = self.project.read(cx).client();
        //     self.feedback = Some(feedback);
        //     cx.notify();

        //     cx.background_spawn(async move {
        //         let final_project_snapshot = final_project_snapshot.await;
        //         let serialized_thread = serialized_thread.await?;
        //         let thread_data = serde_json::to_value(serialized_thread)
        //             .unwrap_or_else(|_| serde_json::Value::Null);

        //         let rating = match feedback {
        //             ThreadFeedback::Positive => "positive",
        //             ThreadFeedback::Negative => "negative",
        //         };
        //         telemetry::event!(
        //             "Assistant Thread Rated",
        //             rating,
        //             thread_id,
        //             thread_data,
        //             final_project_snapshot
        //         );
        //         client.telemetry().flush_events().await;

        //         Ok(())
        //     })
        // }
    }
964
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; dirty-buffer paths are
    /// collected best-effort (errors reading app state are ignored).
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    // Only record buffers with unsaved changes that are
                    // backed by a file (in-memory buffers have no path).
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
1002
    /// Captures a [`WorktreeSnapshot`] for one worktree, including git state
    /// when the worktree belongs to a local repository.
    ///
    /// Every step is best-effort: failures collapse to an empty path or a
    /// `None` git state rather than an error.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository containing this worktree, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can
                            // query; otherwise report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result layers from update + send_job.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
1080
    /// Renders the whole thread as a markdown document. Not yet implemented;
    /// the commented code is the pre-refactor behavior.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        todo!()
        // let mut markdown = Vec::new();

        // let summary = self.summary().or_default();
        // writeln!(markdown, "# {summary}\n")?;

        // for message in self.messages() {
        //     writeln!(
        //         markdown,
        //         "## {role}\n",
        //         role = match message.role {
        //             Role::User => "User",
        //             Role::Assistant => "Agent",
        //             Role::System => "System",
        //         }
        //     )?;

        //     if !message.loaded_context.text.is_empty() {
        //         writeln!(markdown, "{}", message.loaded_context.text)?;
        //     }

        //     if !message.loaded_context.images.is_empty() {
        //         writeln!(
        //             markdown,
        //             "\n{} images attached as context.\n",
        //             message.loaded_context.images.len()
        //         )?;
        //     }

        //     for segment in &message.segments {
        //         match segment {
        //             MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
        //             MessageSegment::Thinking { text, .. } => {
        //                 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
        //             }
        //             MessageSegment::RedactedThinking(_) => {}
        //         }
        //     }

        //     for tool_use in self.tool_uses_for_message(message.id, cx) {
        //         writeln!(
        //             markdown,
        //             "**Use Tool: {} ({})**",
        //             tool_use.name, tool_use.id
        //         )?;
        //         writeln!(markdown, "```json")?;
        //         writeln!(
        //             markdown,
        //             "{}",
        //             serde_json::to_string_pretty(&tool_use.input)?
        //         )?;
        //         writeln!(markdown, "```")?;
        //     }

        //     for tool_result in self.tool_results_for_message(message.id) {
        //         write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
        //         if tool_result.is_error {
        //             write!(markdown, " (Error)")?;
        //         }

        //         writeln!(markdown, "**\n")?;
        //         match &tool_result.content {
        //             LanguageModelToolResultContent::Text(text) => {
        //                 writeln!(markdown, "{text}")?;
        //             }
        //             LanguageModelToolResultContent::Image(image) => {
        //                 writeln!(markdown, "", image.source)?;
        //             }
        //         }

        //         if let Some(output) = tool_result.output.as_ref() {
        //             writeln!(
        //                 markdown,
        //                 "\n\nDebug Output:\n\n```json\n{}\n```\n",
        //                 serde_json::to_string_pretty(output)?
        //             )?;
        //         }
        //     }
        // }

        // Ok(String::from_utf8_lossy(&markdown).to_string())
    }
1164
    /// Accepts agent edits within the given range of a buffer, delegating to
    /// the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
1175
1176 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1177 self.action_log
1178 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1179 }
1180
1181 pub fn reject_edits_in_ranges(
1182 &mut self,
1183 buffer: Entity<language::Buffer>,
1184 buffer_ranges: Vec<Range<language::Anchor>>,
1185 cx: &mut Context<Self>,
1186 ) -> Task<Result<()>> {
1187 self.action_log.update(cx, |action_log, cx| {
1188 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1189 })
1190 }
1191
    /// Returns a reference to this thread's [`ActionLog`] entity.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1195
    /// Returns a reference to the [`Project`] this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1199
    /// Auto-captures a serialized snapshot of this thread for telemetry.
    ///
    /// Currently unimplemented: the body is `todo!()`, so calling this
    /// PANICS. The commented-out code below sketches the intended behavior —
    /// gated behind a feature flag, throttled to at most once every 10
    /// seconds, and emitting a background "Agent Thread Auto-Captured"
    /// telemetry event with the serialized thread.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        todo!()
        // if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
        //     return;
        // }

        // let now = Instant::now();
        // if let Some(last) = self.last_auto_capture_at {
        //     if now.duration_since(last).as_secs() < 10 {
        //         return;
        //     }
        // }

        // self.last_auto_capture_at = Some(now);

        // let thread_id = self.id().clone();
        // let github_login = self
        //     .project
        //     .read(cx)
        //     .user_store()
        //     .read(cx)
        //     .current_user()
        //     .map(|user| user.github_login.clone());
        // let client = self.project.read(cx).client();
        // let serialize_task = self.serialize(cx);

        // cx.background_executor()
        //     .spawn(async move {
        //         if let Ok(serialized_thread) = serialize_task.await {
        //             if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
        //                 telemetry::event!(
        //                     "Agent Thread Auto-Captured",
        //                     thread_id = thread_id.to_string(),
        //                     thread_data = thread_data,
        //                     auto_capture_reason = "tracked_user",
        //                     github_login = github_login
        //                 );

        //                 client.telemetry().flush_events().await;
        //             }
        //         }
        //     })
        //     .detach();
    }
1244
    /// Returns the token usage accumulated across this thread's requests.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
1248
    /// Returns the total token usage recorded up to (but not including)
    /// the message with the given id.
    ///
    /// Currently unimplemented: the body is `todo!()`, so calling this
    /// PANICS. The commented-out code below sketches the intended behavior:
    /// look up the message's index, and report the usage recorded for the
    /// preceding request together with the configured model's maximum token
    /// count.
    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
        todo!()
        // let Some(model) = self.configured_model.as_ref() else {
        //     return TotalTokenUsage::default();
        // };

        // let max = model.model.max_token_count();

        // let index = self
        //     .messages
        //     .iter()
        //     .position(|msg| msg.id == message_id)
        //     .unwrap_or(0);

        // if index == 0 {
        //     return TotalTokenUsage { total: 0, max };
        // }

        // let token_usage = &self
        //     .request_token_usage
        //     .get(index - 1)
        //     .cloned()
        //     .unwrap_or_default();

        // TotalTokenUsage {
        //     total: token_usage.total_tokens(),
        //     max,
        // }
    }
1278
    /// Returns the thread's total token usage for the configured model, or
    /// `None` when no model is configured.
    ///
    /// Currently unimplemented: the body is `todo!()`, so calling this
    /// PANICS. The commented-out code below sketches the intended behavior,
    /// including reporting the exceeded-context-window token count when the
    /// active model previously overflowed its window.
    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
        todo!()
        // let model = self.configured_model.as_ref()?;

        // let max = model.model.max_token_count();

        // if let Some(exceeded_error) = &self.exceeded_window_error {
        //     if model.model.id() == exceeded_error.model_id {
        //         return Some(TotalTokenUsage {
        //             total: exceeded_error.token_count,
        //             max,
        //         });
        //     }
        // }

        // let total = self
        //     .token_usage_at_last_message()
        //     .unwrap_or_default()
        //     .total_tokens();

        // Some(TotalTokenUsage { total, max })
    }
1301
1302 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
1303 self.request_token_usage
1304 .get(self.messages.len().saturating_sub(1))
1305 .or_else(|| self.request_token_usage.last())
1306 .cloned()
1307 }
1308
1309 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
1310 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
1311 self.request_token_usage
1312 .resize(self.messages.len(), placeholder);
1313
1314 if let Some(last) = self.request_token_usage.last_mut() {
1315 *last = token_usage;
1316 }
1317 }
1318
1319 fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
1320 self.project.update(cx, |project, cx| {
1321 project.user_store().update(cx, |user_store, cx| {
1322 user_store.update_model_request_usage(
1323 ModelRequestUsage(RequestUsage {
1324 amount: amount as i32,
1325 limit,
1326 }),
1327 cx,
1328 )
1329 })
1330 });
1331 }
1332}
1333
/// Errors surfaced while running an agent thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The model request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a display header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1346
/// Events emitted by a thread to notify observers of state changes.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    NewRequest,
    /// The request stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// Messages in `old_range` were replaced; `new_length` is the size of
    /// the replacement span.
    MessagesUpdated {
        old_range: Range<usize>,
        new_length: usize,
    },
    SummaryGenerated,
    SummaryChanged,
    CheckpointChanged,
    /// A tool use is awaiting user confirmation before running.
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
    /// All retry attempts failed; `message` describes the final failure.
    RetriesFailed {
        message: SharedString,
    },
}
1369
// Marker impl allowing `Thread` entities to emit `ThreadEvent`s via GPUI.
impl EventEmitter<ThreadEvent> for Thread {}
1371
1372/// Resolves tool name conflicts by ensuring all tool names are unique.
1373///
1374/// When multiple tools have the same name, this function applies the following rules:
1375/// 1. Native tools always keep their original name
1376/// 2. Context server tools get prefixed with their server ID and an underscore
1377/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
1378/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
1379///
1380/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
1381fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
1382 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
1383 let mut tool_name = tool.name();
1384 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
1385 tool_name
1386 }
1387
1388 const MAX_TOOL_NAME_LENGTH: usize = 64;
1389
1390 let mut duplicated_tool_names = HashSet::default();
1391 let mut seen_tool_names = HashSet::default();
1392 for tool in tools {
1393 let tool_name = resolve_tool_name(tool);
1394 if seen_tool_names.contains(&tool_name) {
1395 debug_assert!(
1396 tool.source() != assistant_tool::ToolSource::Native,
1397 "There are two built-in tools with the same name: {}",
1398 tool_name
1399 );
1400 duplicated_tool_names.insert(tool_name);
1401 } else {
1402 seen_tool_names.insert(tool_name);
1403 }
1404 }
1405
1406 if duplicated_tool_names.is_empty() {
1407 return tools
1408 .into_iter()
1409 .map(|tool| (resolve_tool_name(tool), tool.clone()))
1410 .collect();
1411 }
1412
1413 tools
1414 .into_iter()
1415 .filter_map(|tool| {
1416 let mut tool_name = resolve_tool_name(tool);
1417 if !duplicated_tool_names.contains(&tool_name) {
1418 return Some((tool_name, tool.clone()));
1419 }
1420 match tool.source() {
1421 assistant_tool::ToolSource::Native => {
1422 // Built-in tools always keep their original name
1423 Some((tool_name, tool.clone()))
1424 }
1425 assistant_tool::ToolSource::ContextServer { id } => {
1426 // Context server tools are prefixed with the context server ID, and truncated if necessary
1427 tool_name.insert(0, '_');
1428 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
1429 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
1430 let mut id = id.to_string();
1431 id.truncate(len);
1432 tool_name.insert_str(0, &id);
1433 } else {
1434 tool_name.insert_str(0, &id);
1435 }
1436
1437 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
1438
1439 if seen_tool_names.contains(&tool_name) {
1440 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
1441 None
1442 } else {
1443 Some((tool_name, tool.clone()))
1444 }
1445 }
1446 }
1447 })
1448 .collect()
1449}
1450
1451// #[cfg(test)]
1452// mod tests {
1453// use super::*;
1454// use crate::{
1455// context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
1456// };
1457
1458// // Test-specific constants
1459// const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
1460// use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
1461// use assistant_tool::ToolRegistry;
1462// use futures::StreamExt;
1463// use futures::future::BoxFuture;
1464// use futures::stream::BoxStream;
1465// use gpui::TestAppContext;
1466// use icons::IconName;
1467// use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1468// use language_model::{
1469// LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
1470// LanguageModelProviderName, LanguageModelToolChoice,
1471// };
1472// use parking_lot::Mutex;
1473// use project::{FakeFs, Project};
1474// use prompt_store::PromptBuilder;
1475// use serde_json::json;
1476// use settings::{Settings, SettingsStore};
1477// use std::sync::Arc;
1478// use std::time::Duration;
1479// use theme::ThemeSettings;
1480// use util::path;
1481// use workspace::Workspace;
1482
1483// #[gpui::test]
1484// async fn test_message_with_context(cx: &mut TestAppContext) {
1485// init_test_settings(cx);
1486
1487// let project = create_test_project(
1488// cx,
1489// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1490// )
1491// .await;
1492
1493// let (_workspace, _thread_store, thread, context_store, model) =
1494// setup_test_environment(cx, project.clone()).await;
1495
1496// add_file_to_context(&project, &context_store, "test/code.rs", cx)
1497// .await
1498// .unwrap();
1499
1500// let context =
1501// context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
1502// let loaded_context = cx
1503// .update(|cx| load_context(vec![context], &project, &None, cx))
1504// .await;
1505
1506// // Insert user message with context
1507// let message_id = thread.update(cx, |thread, cx| {
1508// thread.insert_user_message(
1509// "Please explain this code",
1510// loaded_context,
1511// None,
1512// Vec::new(),
1513// cx,
1514// )
1515// });
1516
1517// // Check content and context in message object
1518// let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
1519
1520// // Use different path format strings based on platform for the test
1521// #[cfg(windows)]
1522// let path_part = r"test\code.rs";
1523// #[cfg(not(windows))]
1524// let path_part = "test/code.rs";
1525
1526// let expected_context = format!(
1527// r#"
1528// <context>
1529// The following items were attached by the user. They are up-to-date and don't need to be re-read.
1530
1531// <files>
1532// ```rs {path_part}
1533// fn main() {{
1534// println!("Hello, world!");
1535// }}
1536// ```
1537// </files>
1538// </context>
1539// "#
1540// );
1541
1542// assert_eq!(message.role, Role::User);
1543// assert_eq!(message.segments.len(), 1);
1544// assert_eq!(
1545// message.segments[0],
1546// MessageSegment::Text("Please explain this code".to_string())
1547// );
1548// assert_eq!(message.loaded_context.text, expected_context);
1549
1550// // Check message in request
1551// let request = thread.update(cx, |thread, cx| {
1552// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1553// });
1554
1555// assert_eq!(request.messages.len(), 2);
1556// let expected_full_message = format!("{}Please explain this code", expected_context);
1557// assert_eq!(request.messages[1].string_contents(), expected_full_message);
1558// }
1559
1560// #[gpui::test]
1561// async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
1562// init_test_settings(cx);
1563
1564// let project = create_test_project(
1565// cx,
1566// json!({
1567// "file1.rs": "fn function1() {}\n",
1568// "file2.rs": "fn function2() {}\n",
1569// "file3.rs": "fn function3() {}\n",
1570// "file4.rs": "fn function4() {}\n",
1571// }),
1572// )
1573// .await;
1574
1575// let (_, _thread_store, thread, context_store, model) =
1576// setup_test_environment(cx, project.clone()).await;
1577
1578// // First message with context 1
1579// add_file_to_context(&project, &context_store, "test/file1.rs", cx)
1580// .await
1581// .unwrap();
1582// let new_contexts = context_store.update(cx, |store, cx| {
1583// store.new_context_for_thread(thread.read(cx), None)
1584// });
1585// assert_eq!(new_contexts.len(), 1);
1586// let loaded_context = cx
1587// .update(|cx| load_context(new_contexts, &project, &None, cx))
1588// .await;
1589// let message1_id = thread.update(cx, |thread, cx| {
1590// thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
1591// });
1592
1593// // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
1594// add_file_to_context(&project, &context_store, "test/file2.rs", cx)
1595// .await
1596// .unwrap();
1597// let new_contexts = context_store.update(cx, |store, cx| {
1598// store.new_context_for_thread(thread.read(cx), None)
1599// });
1600// assert_eq!(new_contexts.len(), 1);
1601// let loaded_context = cx
1602// .update(|cx| load_context(new_contexts, &project, &None, cx))
1603// .await;
1604// let message2_id = thread.update(cx, |thread, cx| {
1605// thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
1606// });
1607
1608// // Third message with all three contexts (contexts 1 and 2 should be skipped)
1609// //
1610// add_file_to_context(&project, &context_store, "test/file3.rs", cx)
1611// .await
1612// .unwrap();
1613// let new_contexts = context_store.update(cx, |store, cx| {
1614// store.new_context_for_thread(thread.read(cx), None)
1615// });
1616// assert_eq!(new_contexts.len(), 1);
1617// let loaded_context = cx
1618// .update(|cx| load_context(new_contexts, &project, &None, cx))
1619// .await;
1620// let message3_id = thread.update(cx, |thread, cx| {
1621// thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
1622// });
1623
1624// // Check what contexts are included in each message
1625// let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
1626// (
1627// thread.message(message1_id).unwrap().clone(),
1628// thread.message(message2_id).unwrap().clone(),
1629// thread.message(message3_id).unwrap().clone(),
1630// )
1631// });
1632
1633// // First message should include context 1
1634// assert!(message1.loaded_context.text.contains("file1.rs"));
1635
1636// // Second message should include only context 2 (not 1)
1637// assert!(!message2.loaded_context.text.contains("file1.rs"));
1638// assert!(message2.loaded_context.text.contains("file2.rs"));
1639
1640// // Third message should include only context 3 (not 1 or 2)
1641// assert!(!message3.loaded_context.text.contains("file1.rs"));
1642// assert!(!message3.loaded_context.text.contains("file2.rs"));
1643// assert!(message3.loaded_context.text.contains("file3.rs"));
1644
1645// // Check entire request to make sure all contexts are properly included
1646// let request = thread.update(cx, |thread, cx| {
1647// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1648// });
1649
1650// // The request should contain all 3 messages
1651// assert_eq!(request.messages.len(), 4);
1652
1653// // Check that the contexts are properly formatted in each message
1654// assert!(request.messages[1].string_contents().contains("file1.rs"));
1655// assert!(!request.messages[1].string_contents().contains("file2.rs"));
1656// assert!(!request.messages[1].string_contents().contains("file3.rs"));
1657
1658// assert!(!request.messages[2].string_contents().contains("file1.rs"));
1659// assert!(request.messages[2].string_contents().contains("file2.rs"));
1660// assert!(!request.messages[2].string_contents().contains("file3.rs"));
1661
1662// assert!(!request.messages[3].string_contents().contains("file1.rs"));
1663// assert!(!request.messages[3].string_contents().contains("file2.rs"));
1664// assert!(request.messages[3].string_contents().contains("file3.rs"));
1665
1666// add_file_to_context(&project, &context_store, "test/file4.rs", cx)
1667// .await
1668// .unwrap();
1669// let new_contexts = context_store.update(cx, |store, cx| {
1670// store.new_context_for_thread(thread.read(cx), Some(message2_id))
1671// });
1672// assert_eq!(new_contexts.len(), 3);
1673// let loaded_context = cx
1674// .update(|cx| load_context(new_contexts, &project, &None, cx))
1675// .await
1676// .loaded_context;
1677
1678// assert!(!loaded_context.text.contains("file1.rs"));
1679// assert!(loaded_context.text.contains("file2.rs"));
1680// assert!(loaded_context.text.contains("file3.rs"));
1681// assert!(loaded_context.text.contains("file4.rs"));
1682
1683// let new_contexts = context_store.update(cx, |store, cx| {
1684// // Remove file4.rs
1685// store.remove_context(&loaded_context.contexts[2].handle(), cx);
1686// store.new_context_for_thread(thread.read(cx), Some(message2_id))
1687// });
1688// assert_eq!(new_contexts.len(), 2);
1689// let loaded_context = cx
1690// .update(|cx| load_context(new_contexts, &project, &None, cx))
1691// .await
1692// .loaded_context;
1693
1694// assert!(!loaded_context.text.contains("file1.rs"));
1695// assert!(loaded_context.text.contains("file2.rs"));
1696// assert!(loaded_context.text.contains("file3.rs"));
1697// assert!(!loaded_context.text.contains("file4.rs"));
1698
1699// let new_contexts = context_store.update(cx, |store, cx| {
1700// // Remove file3.rs
1701// store.remove_context(&loaded_context.contexts[1].handle(), cx);
1702// store.new_context_for_thread(thread.read(cx), Some(message2_id))
1703// });
1704// assert_eq!(new_contexts.len(), 1);
1705// let loaded_context = cx
1706// .update(|cx| load_context(new_contexts, &project, &None, cx))
1707// .await
1708// .loaded_context;
1709
1710// assert!(!loaded_context.text.contains("file1.rs"));
1711// assert!(loaded_context.text.contains("file2.rs"));
1712// assert!(!loaded_context.text.contains("file3.rs"));
1713// assert!(!loaded_context.text.contains("file4.rs"));
1714// }
1715
1716// #[gpui::test]
1717// async fn test_message_without_files(cx: &mut TestAppContext) {
1718// init_test_settings(cx);
1719
1720// let project = create_test_project(
1721// cx,
1722// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1723// )
1724// .await;
1725
1726// let (_, _thread_store, thread, _context_store, model) =
1727// setup_test_environment(cx, project.clone()).await;
1728
1729// // Insert user message without any context (empty context vector)
1730// let message_id = thread.update(cx, |thread, cx| {
1731// thread.insert_user_message(
1732// "What is the best way to learn Rust?",
1733// ContextLoadResult::default(),
1734// None,
1735// Vec::new(),
1736// cx,
1737// )
1738// });
1739
1740// // Check content and context in message object
1741// let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
1742
1743// // Context should be empty when no files are included
1744// assert_eq!(message.role, Role::User);
1745// assert_eq!(message.segments.len(), 1);
1746// assert_eq!(
1747// message.segments[0],
1748// MessageSegment::Text("What is the best way to learn Rust?".to_string())
1749// );
1750// assert_eq!(message.loaded_context.text, "");
1751
1752// // Check message in request
1753// let request = thread.update(cx, |thread, cx| {
1754// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1755// });
1756
1757// assert_eq!(request.messages.len(), 2);
1758// assert_eq!(
1759// request.messages[1].string_contents(),
1760// "What is the best way to learn Rust?"
1761// );
1762
1763// // Add second message, also without context
1764// let message2_id = thread.update(cx, |thread, cx| {
1765// thread.insert_user_message(
1766// "Are there any good books?",
1767// ContextLoadResult::default(),
1768// None,
1769// Vec::new(),
1770// cx,
1771// )
1772// });
1773
1774// let message2 =
1775// thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
1776// assert_eq!(message2.loaded_context.text, "");
1777
1778// // Check that both messages appear in the request
1779// let request = thread.update(cx, |thread, cx| {
1780// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1781// });
1782
1783// assert_eq!(request.messages.len(), 3);
1784// assert_eq!(
1785// request.messages[1].string_contents(),
1786// "What is the best way to learn Rust?"
1787// );
1788// assert_eq!(
1789// request.messages[2].string_contents(),
1790// "Are there any good books?"
1791// );
1792// }
1793
1794// #[gpui::test]
1795// async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
1796// init_test_settings(cx);
1797
1798// let project = create_test_project(
1799// cx,
1800// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1801// )
1802// .await;
1803
1804// let (_workspace, thread_store, thread, _context_store, _model) =
1805// setup_test_environment(cx, project.clone()).await;
1806
1807// // Check that we are starting with the default profile
1808// let profile = cx.read(|cx| thread.read(cx).profile.clone());
1809// let tool_set = cx.read(|cx| thread_store.read(cx).tools());
1810// assert_eq!(
1811// profile,
1812// AgentProfile::new(AgentProfileId::default(), tool_set)
1813// );
1814// }
1815
1816// #[gpui::test]
1817// async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
1818// init_test_settings(cx);
1819
1820// let project = create_test_project(
1821// cx,
1822// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1823// )
1824// .await;
1825
1826// let (_workspace, thread_store, thread, _context_store, _model) =
1827// setup_test_environment(cx, project.clone()).await;
1828
1829// // Profile gets serialized with default values
1830// let serialized = thread
1831// .update(cx, |thread, cx| thread.serialize(cx))
1832// .await
1833// .unwrap();
1834
1835// assert_eq!(serialized.profile, Some(AgentProfileId::default()));
1836
1837// let deserialized = cx.update(|cx| {
1838// thread.update(cx, |thread, cx| {
1839// Thread::deserialize(
1840// thread.id.clone(),
1841// serialized,
1842// thread.project.clone(),
1843// thread.tools.clone(),
1844// thread.prompt_builder.clone(),
1845// thread.project_context.clone(),
1846// None,
1847// cx,
1848// )
1849// })
1850// });
1851// let tool_set = cx.read(|cx| thread_store.read(cx).tools());
1852
1853// assert_eq!(
1854// deserialized.profile,
1855// AgentProfile::new(AgentProfileId::default(), tool_set)
1856// );
1857// }
1858
1859// #[gpui::test]
1860// async fn test_temperature_setting(cx: &mut TestAppContext) {
1861// init_test_settings(cx);
1862
1863// let project = create_test_project(
1864// cx,
1865// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1866// )
1867// .await;
1868
1869// let (_workspace, _thread_store, thread, _context_store, model) =
1870// setup_test_environment(cx, project.clone()).await;
1871
1872// // Both model and provider
1873// cx.update(|cx| {
1874// AgentSettings::override_global(
1875// AgentSettings {
1876// model_parameters: vec![LanguageModelParameters {
1877// provider: Some(model.provider_id().0.to_string().into()),
1878// model: Some(model.id().0.clone()),
1879// temperature: Some(0.66),
1880// }],
1881// ..AgentSettings::get_global(cx).clone()
1882// },
1883// cx,
1884// );
1885// });
1886
1887// let request = thread.update(cx, |thread, cx| {
1888// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1889// });
1890// assert_eq!(request.temperature, Some(0.66));
1891
1892// // Only model
1893// cx.update(|cx| {
1894// AgentSettings::override_global(
1895// AgentSettings {
1896// model_parameters: vec![LanguageModelParameters {
1897// provider: None,
1898// model: Some(model.id().0.clone()),
1899// temperature: Some(0.66),
1900// }],
1901// ..AgentSettings::get_global(cx).clone()
1902// },
1903// cx,
1904// );
1905// });
1906
1907// let request = thread.update(cx, |thread, cx| {
1908// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1909// });
1910// assert_eq!(request.temperature, Some(0.66));
1911
1912// // Only provider
1913// cx.update(|cx| {
1914// AgentSettings::override_global(
1915// AgentSettings {
1916// model_parameters: vec![LanguageModelParameters {
1917// provider: Some(model.provider_id().0.to_string().into()),
1918// model: None,
1919// temperature: Some(0.66),
1920// }],
1921// ..AgentSettings::get_global(cx).clone()
1922// },
1923// cx,
1924// );
1925// });
1926
1927// let request = thread.update(cx, |thread, cx| {
1928// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1929// });
1930// assert_eq!(request.temperature, Some(0.66));
1931
1932// // Same model name, different provider
1933// cx.update(|cx| {
1934// AgentSettings::override_global(
1935// AgentSettings {
1936// model_parameters: vec![LanguageModelParameters {
1937// provider: Some("anthropic".into()),
1938// model: Some(model.id().0.clone()),
1939// temperature: Some(0.66),
1940// }],
1941// ..AgentSettings::get_global(cx).clone()
1942// },
1943// cx,
1944// );
1945// });
1946
1947// let request = thread.update(cx, |thread, cx| {
1948// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
1949// });
1950// assert_eq!(request.temperature, None);
1951// }
1952
1953// #[gpui::test]
1954// async fn test_thread_summary(cx: &mut TestAppContext) {
1955// init_test_settings(cx);
1956
1957// let project = create_test_project(cx, json!({})).await;
1958
1959// let (_, _thread_store, thread, _context_store, model) =
1960// setup_test_environment(cx, project.clone()).await;
1961
1962// // Initial state should be pending
1963// thread.read_with(cx, |thread, _| {
1964// assert!(matches!(thread.summary(), ThreadSummary::Pending));
1965// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
1966// });
1967
1968// // Manually setting the summary should not be allowed in this state
1969// thread.update(cx, |thread, cx| {
1970// thread.set_summary("This should not work", cx);
1971// });
1972
1973// thread.read_with(cx, |thread, _| {
1974// assert!(matches!(thread.summary(), ThreadSummary::Pending));
1975// });
1976
1977// // Send a message
1978// thread.update(cx, |thread, cx| {
1979// thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
1980// thread.send_to_model(
1981// model.clone(),
1982// CompletionIntent::ThreadSummarization,
1983// None,
1984// cx,
1985// );
1986// });
1987
1988// let fake_model = model.as_fake();
1989// simulate_successful_response(&fake_model, cx);
1990
1991// // Should start generating summary when there are >= 2 messages
1992// thread.read_with(cx, |thread, _| {
1993// assert_eq!(*thread.summary(), ThreadSummary::Generating);
1994// });
1995
1996// // Should not be able to set the summary while generating
1997// thread.update(cx, |thread, cx| {
1998// thread.set_summary("This should not work either", cx);
1999// });
2000
2001// thread.read_with(cx, |thread, _| {
2002// assert!(matches!(thread.summary(), ThreadSummary::Generating));
2003// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
2004// });
2005
2006// cx.run_until_parked();
2007// fake_model.stream_last_completion_response("Brief");
2008// fake_model.stream_last_completion_response(" Introduction");
2009// fake_model.end_last_completion_stream();
2010// cx.run_until_parked();
2011
2012// // Summary should be set
2013// thread.read_with(cx, |thread, _| {
2014// assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
2015// assert_eq!(thread.summary().or_default(), "Brief Introduction");
2016// });
2017
2018// // Now we should be able to set a summary
2019// thread.update(cx, |thread, cx| {
2020// thread.set_summary("Brief Intro", cx);
2021// });
2022
2023// thread.read_with(cx, |thread, _| {
2024// assert_eq!(thread.summary().or_default(), "Brief Intro");
2025// });
2026
2027// // Test setting an empty summary (should default to DEFAULT)
2028// thread.update(cx, |thread, cx| {
2029// thread.set_summary("", cx);
2030// });
2031
2032// thread.read_with(cx, |thread, _| {
2033// assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
2034// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
2035// });
2036// }
2037
2038// #[gpui::test]
2039// async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
2040// init_test_settings(cx);
2041
2042// let project = create_test_project(cx, json!({})).await;
2043
2044// let (_, _thread_store, thread, _context_store, model) =
2045// setup_test_environment(cx, project.clone()).await;
2046
2047// test_summarize_error(&model, &thread, cx);
2048
2049// // Now we should be able to set a summary
2050// thread.update(cx, |thread, cx| {
2051// thread.set_summary("Brief Intro", cx);
2052// });
2053
2054// thread.read_with(cx, |thread, _| {
2055// assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
2056// assert_eq!(thread.summary().or_default(), "Brief Intro");
2057// });
2058// }
2059
2060// #[gpui::test]
2061// async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
2062// init_test_settings(cx);
2063
2064// let project = create_test_project(cx, json!({})).await;
2065
2066// let (_, _thread_store, thread, _context_store, model) =
2067// setup_test_environment(cx, project.clone()).await;
2068
2069// test_summarize_error(&model, &thread, cx);
2070
2071// // Sending another message should not trigger another summarize request
2072// thread.update(cx, |thread, cx| {
2073// thread.insert_user_message(
2074// "How are you?",
2075// ContextLoadResult::default(),
2076// None,
2077// vec![],
2078// cx,
2079// );
2080// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2081// });
2082
2083// let fake_model = model.as_fake();
2084// simulate_successful_response(&fake_model, cx);
2085
2086// thread.read_with(cx, |thread, _| {
2087// // State is still Error, not Generating
2088// assert!(matches!(thread.summary(), ThreadSummary::Error));
2089// });
2090
2091// // But the summarize request can be invoked manually
2092// thread.update(cx, |thread, cx| {
2093// thread.summarize(cx);
2094// });
2095
2096// thread.read_with(cx, |thread, _| {
2097// assert!(matches!(thread.summary(), ThreadSummary::Generating));
2098// });
2099
2100// cx.run_until_parked();
2101// fake_model.stream_last_completion_response("A successful summary");
2102// fake_model.end_last_completion_stream();
2103// cx.run_until_parked();
2104
2105// thread.read_with(cx, |thread, _| {
2106// assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
2107// assert_eq!(thread.summary().or_default(), "A successful summary");
2108// });
2109// }
2110
2111// #[gpui::test]
2112// fn test_resolve_tool_name_conflicts() {
2113// use assistant_tool::{Tool, ToolSource};
2114
2115// assert_resolve_tool_name_conflicts(
2116// vec![
2117// TestTool::new("tool1", ToolSource::Native),
2118// TestTool::new("tool2", ToolSource::Native),
2119// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
2120// ],
2121// vec!["tool1", "tool2", "tool3"],
2122// );
2123
2124// assert_resolve_tool_name_conflicts(
2125// vec![
2126// TestTool::new("tool1", ToolSource::Native),
2127// TestTool::new("tool2", ToolSource::Native),
2128// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
2129// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
2130// ],
2131// vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
2132// );
2133
2134// assert_resolve_tool_name_conflicts(
2135// vec![
2136// TestTool::new("tool1", ToolSource::Native),
2137// TestTool::new("tool2", ToolSource::Native),
2138// TestTool::new("tool3", ToolSource::Native),
2139// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
2140// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
2141// ],
2142// vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
2143// );
2144
// // Test that tool with very long name is always truncated
// assert_resolve_tool_name_conflicts(
// vec![TestTool::new(
// "tool-with-more-than-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
// ToolSource::Native,
// )],
// vec!["tool-with-more-than-64-characters-blah-blah-blah-blah-blah-blah-"],
// );
2153
// // Test deduplication of tools with very long names; in this case the MCP server name should be truncated
2155// assert_resolve_tool_name_conflicts(
2156// vec![
2157// TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
2158// TestTool::new(
2159// "tool-with-very-very-very-long-name",
2160// ToolSource::ContextServer {
2161// id: "mcp-with-very-very-very-long-name".into(),
2162// },
2163// ),
2164// ],
2165// vec![
2166// "tool-with-very-very-very-long-name",
2167// "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
2168// ],
2169// );
2170
2171// fn assert_resolve_tool_name_conflicts(
2172// tools: Vec<TestTool>,
2173// expected: Vec<impl Into<String>>,
2174// ) {
2175// let tools: Vec<Arc<dyn Tool>> = tools
2176// .into_iter()
2177// .map(|t| Arc::new(t) as Arc<dyn Tool>)
2178// .collect();
2179// let tools = resolve_tool_name_conflicts(&tools);
2180// assert_eq!(tools.len(), expected.len());
2181// for (i, expected_name) in expected.into_iter().enumerate() {
2182// let expected_name = expected_name.into();
2183// let actual_name = &tools[i].0;
2184// assert_eq!(
2185// actual_name, &expected_name,
2186// "Expected '{}' got '{}' at index {}",
2187// expected_name, actual_name, i
2188// );
2189// }
2190// }
2191
2192// struct TestTool {
2193// name: String,
2194// source: ToolSource,
2195// }
2196
2197// impl TestTool {
2198// fn new(name: impl Into<String>, source: ToolSource) -> Self {
2199// Self {
2200// name: name.into(),
2201// source,
2202// }
2203// }
2204// }
2205
2206// impl Tool for TestTool {
2207// fn name(&self) -> String {
2208// self.name.clone()
2209// }
2210
2211// fn icon(&self) -> IconName {
2212// IconName::Ai
2213// }
2214
2215// fn may_perform_edits(&self) -> bool {
2216// false
2217// }
2218
2219// fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
2220// true
2221// }
2222
2223// fn source(&self) -> ToolSource {
2224// self.source.clone()
2225// }
2226
2227// fn description(&self) -> String {
2228// "Test tool".to_string()
2229// }
2230
2231// fn ui_text(&self, _input: &serde_json::Value) -> String {
2232// "Test tool".to_string()
2233// }
2234
2235// fn run(
2236// self: Arc<Self>,
2237// _input: serde_json::Value,
2238// _request: Arc<LanguageModelRequest>,
2239// _project: Entity<Project>,
2240// _action_log: Entity<ActionLog>,
2241// _model: Arc<dyn LanguageModel>,
2242// _window: Option<AnyWindowHandle>,
2243// _cx: &mut App,
2244// ) -> assistant_tool::ToolResult {
2245// assistant_tool::ToolResult {
2246// output: Task::ready(Err(anyhow::anyhow!("No content"))),
2247// card: None,
2248// }
2249// }
2250// }
2251// }
2252
2253// // Helper to create a model that returns errors
2254// enum TestError {
2255// Overloaded,
2256// InternalServerError,
2257// }
2258
2259// struct ErrorInjector {
2260// inner: Arc<FakeLanguageModel>,
2261// error_type: TestError,
2262// }
2263
2264// impl ErrorInjector {
2265// fn new(error_type: TestError) -> Self {
2266// Self {
2267// inner: Arc::new(FakeLanguageModel::default()),
2268// error_type,
2269// }
2270// }
2271// }
2272
2273// impl LanguageModel for ErrorInjector {
2274// fn id(&self) -> LanguageModelId {
2275// self.inner.id()
2276// }
2277
2278// fn name(&self) -> LanguageModelName {
2279// self.inner.name()
2280// }
2281
2282// fn provider_id(&self) -> LanguageModelProviderId {
2283// self.inner.provider_id()
2284// }
2285
2286// fn provider_name(&self) -> LanguageModelProviderName {
2287// self.inner.provider_name()
2288// }
2289
2290// fn supports_tools(&self) -> bool {
2291// self.inner.supports_tools()
2292// }
2293
2294// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
2295// self.inner.supports_tool_choice(choice)
2296// }
2297
2298// fn supports_images(&self) -> bool {
2299// self.inner.supports_images()
2300// }
2301
2302// fn telemetry_id(&self) -> String {
2303// self.inner.telemetry_id()
2304// }
2305
2306// fn max_token_count(&self) -> u64 {
2307// self.inner.max_token_count()
2308// }
2309
2310// fn count_tokens(
2311// &self,
2312// request: LanguageModelRequest,
2313// cx: &App,
2314// ) -> BoxFuture<'static, Result<u64>> {
2315// self.inner.count_tokens(request, cx)
2316// }
2317
2318// fn stream_completion(
2319// &self,
2320// _request: LanguageModelRequest,
2321// _cx: &AsyncApp,
2322// ) -> BoxFuture<
2323// 'static,
2324// Result<
2325// BoxStream<
2326// 'static,
2327// Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
2328// >,
2329// LanguageModelCompletionError,
2330// >,
2331// > {
2332// let error = match self.error_type {
2333// TestError::Overloaded => LanguageModelCompletionError::Overloaded,
2334// TestError::InternalServerError => {
2335// LanguageModelCompletionError::ApiInternalServerError
2336// }
2337// };
2338// async move {
2339// let stream = futures::stream::once(async move { Err(error) });
2340// Ok(stream.boxed())
2341// }
2342// .boxed()
2343// }
2344
2345// fn as_fake(&self) -> &FakeLanguageModel {
2346// &self.inner
2347// }
2348// }
2349
2350// #[gpui::test]
2351// async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
2352// init_test_settings(cx);
2353
2354// let project = create_test_project(cx, json!({})).await;
2355// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2356
2357// // Create model that returns overloaded error
2358// let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
2359
2360// // Insert a user message
2361// thread.update(cx, |thread, cx| {
2362// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2363// });
2364
2365// // Start completion
2366// thread.update(cx, |thread, cx| {
2367// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2368// });
2369
2370// cx.run_until_parked();
2371
2372// thread.read_with(cx, |thread, _| {
2373// assert!(thread.retry_state.is_some(), "Should have retry state");
2374// let retry_state = thread.retry_state.as_ref().unwrap();
2375// assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
2376// assert_eq!(
2377// retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
2378// "Should have default max attempts"
2379// );
2380// });
2381
2382// // Check that a retry message was added
2383// thread.read_with(cx, |thread, _| {
2384// let mut messages = thread.messages();
2385// assert!(
2386// messages.any(|msg| {
2387// msg.role == Role::System
2388// && msg.ui_only
2389// && msg.segments.iter().any(|seg| {
2390// if let MessageSegment::Text(text) = seg {
2391// text.contains("overloaded")
2392// && text
2393// .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
2394// } else {
2395// false
2396// }
2397// })
2398// }),
2399// "Should have added a system retry message"
2400// );
2401// });
2402
2403// let retry_count = thread.update(cx, |thread, _| {
2404// thread
2405// .messages
2406// .iter()
2407// .filter(|m| {
2408// m.ui_only
2409// && m.segments.iter().any(|s| {
2410// if let MessageSegment::Text(text) = s {
2411// text.contains("Retrying") && text.contains("seconds")
2412// } else {
2413// false
2414// }
2415// })
2416// })
2417// .count()
2418// });
2419
2420// assert_eq!(retry_count, 1, "Should have one retry message");
2421// }
2422
2423// #[gpui::test]
2424// async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
2425// init_test_settings(cx);
2426
2427// let project = create_test_project(cx, json!({})).await;
2428// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2429
2430// // Create model that returns internal server error
2431// let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
2432
2433// // Insert a user message
2434// thread.update(cx, |thread, cx| {
2435// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2436// });
2437
2438// // Start completion
2439// thread.update(cx, |thread, cx| {
2440// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2441// });
2442
2443// cx.run_until_parked();
2444
2445// // Check retry state on thread
2446// thread.read_with(cx, |thread, _| {
2447// assert!(thread.retry_state.is_some(), "Should have retry state");
2448// let retry_state = thread.retry_state.as_ref().unwrap();
2449// assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
2450// assert_eq!(
2451// retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
2452// "Should have correct max attempts"
2453// );
2454// });
2455
2456// // Check that a retry message was added with provider name
2457// thread.read_with(cx, |thread, _| {
2458// let mut messages = thread.messages();
2459// assert!(
2460// messages.any(|msg| {
2461// msg.role == Role::System
2462// && msg.ui_only
2463// && msg.segments.iter().any(|seg| {
2464// if let MessageSegment::Text(text) = seg {
2465// text.contains("internal")
2466// && text.contains("Fake")
2467// && text
2468// .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
2469// } else {
2470// false
2471// }
2472// })
2473// }),
2474// "Should have added a system retry message with provider name"
2475// );
2476// });
2477
2478// // Count retry messages
2479// let retry_count = thread.update(cx, |thread, _| {
2480// thread
2481// .messages
2482// .iter()
2483// .filter(|m| {
2484// m.ui_only
2485// && m.segments.iter().any(|s| {
2486// if let MessageSegment::Text(text) = s {
2487// text.contains("Retrying") && text.contains("seconds")
2488// } else {
2489// false
2490// }
2491// })
2492// })
2493// .count()
2494// });
2495
2496// assert_eq!(retry_count, 1, "Should have one retry message");
2497// }
2498
2499// #[gpui::test]
2500// async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
2501// init_test_settings(cx);
2502
2503// let project = create_test_project(cx, json!({})).await;
2504// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2505
2506// // Create model that returns overloaded error
2507// let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
2508
2509// // Insert a user message
2510// thread.update(cx, |thread, cx| {
2511// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2512// });
2513
// // Track how many completion requests are made (initial attempt + retries)
2516// let completion_count = Arc::new(Mutex::new(0));
2517// let completion_count_clone = completion_count.clone();
2518
2519// let _subscription = thread.update(cx, |_, cx| {
2520// cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
2521// if let ThreadEvent::NewRequest = event {
2522// *completion_count_clone.lock() += 1;
2523// }
2524// })
2525// });
2526
2527// // First attempt
2528// thread.update(cx, |thread, cx| {
2529// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2530// });
2531// cx.run_until_parked();
2532
2533// // Should have scheduled first retry - count retry messages
2534// let retry_count = thread.update(cx, |thread, _| {
2535// thread
2536// .messages
2537// .iter()
2538// .filter(|m| {
2539// m.ui_only
2540// && m.segments.iter().any(|s| {
2541// if let MessageSegment::Text(text) = s {
2542// text.contains("Retrying") && text.contains("seconds")
2543// } else {
2544// false
2545// }
2546// })
2547// })
2548// .count()
2549// });
2550// assert_eq!(retry_count, 1, "Should have scheduled first retry");
2551
2552// // Check retry state
2553// thread.read_with(cx, |thread, _| {
2554// assert!(thread.retry_state.is_some(), "Should have retry state");
2555// let retry_state = thread.retry_state.as_ref().unwrap();
2556// assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
2557// });
2558
2559// // Advance clock for first retry
2560// cx.executor()
2561// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
2562// cx.run_until_parked();
2563
2564// // Should have scheduled second retry - count retry messages
2565// let retry_count = thread.update(cx, |thread, _| {
2566// thread
2567// .messages
2568// .iter()
2569// .filter(|m| {
2570// m.ui_only
2571// && m.segments.iter().any(|s| {
2572// if let MessageSegment::Text(text) = s {
2573// text.contains("Retrying") && text.contains("seconds")
2574// } else {
2575// false
2576// }
2577// })
2578// })
2579// .count()
2580// });
2581// assert_eq!(retry_count, 2, "Should have scheduled second retry");
2582
2583// // Check retry state updated
2584// thread.read_with(cx, |thread, _| {
2585// assert!(thread.retry_state.is_some(), "Should have retry state");
2586// let retry_state = thread.retry_state.as_ref().unwrap();
2587// assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
2588// assert_eq!(
2589// retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
2590// "Should have correct max attempts"
2591// );
2592// });
2593
2594// // Advance clock for second retry (exponential backoff)
2595// cx.executor()
2596// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
2597// cx.run_until_parked();
2598
2599// // Should have scheduled third retry
2600// // Count all retry messages now
2601// let retry_count = thread.update(cx, |thread, _| {
2602// thread
2603// .messages
2604// .iter()
2605// .filter(|m| {
2606// m.ui_only
2607// && m.segments.iter().any(|s| {
2608// if let MessageSegment::Text(text) = s {
2609// text.contains("Retrying") && text.contains("seconds")
2610// } else {
2611// false
2612// }
2613// })
2614// })
2615// .count()
2616// });
2617// assert_eq!(
2618// retry_count, MAX_RETRY_ATTEMPTS as usize,
2619// "Should have scheduled third retry"
2620// );
2621
2622// // Check retry state updated
2623// thread.read_with(cx, |thread, _| {
2624// assert!(thread.retry_state.is_some(), "Should have retry state");
2625// let retry_state = thread.retry_state.as_ref().unwrap();
2626// assert_eq!(
2627// retry_state.attempt, MAX_RETRY_ATTEMPTS,
2628// "Should be at max retry attempt"
2629// );
2630// assert_eq!(
2631// retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
2632// "Should have correct max attempts"
2633// );
2634// });
2635
2636// // Advance clock for third retry (exponential backoff)
2637// cx.executor()
2638// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
2639// cx.run_until_parked();
2640
2641// // No more retries should be scheduled after clock was advanced.
2642// let retry_count = thread.update(cx, |thread, _| {
2643// thread
2644// .messages
2645// .iter()
2646// .filter(|m| {
2647// m.ui_only
2648// && m.segments.iter().any(|s| {
2649// if let MessageSegment::Text(text) = s {
2650// text.contains("Retrying") && text.contains("seconds")
2651// } else {
2652// false
2653// }
2654// })
2655// })
2656// .count()
2657// });
2658// assert_eq!(
2659// retry_count, MAX_RETRY_ATTEMPTS as usize,
2660// "Should not exceed max retries"
2661// );
2662
2663// // Final completion count should be initial + max retries
2664// assert_eq!(
2665// *completion_count.lock(),
2666// (MAX_RETRY_ATTEMPTS + 1) as usize,
2667// "Should have made initial + max retry attempts"
2668// );
2669// }
2670
2671// #[gpui::test]
2672// async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
2673// init_test_settings(cx);
2674
2675// let project = create_test_project(cx, json!({})).await;
2676// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2677
2678// // Create model that returns overloaded error
2679// let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
2680
2681// // Insert a user message
2682// thread.update(cx, |thread, cx| {
2683// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2684// });
2685
2686// // Track events
2687// let retries_failed = Arc::new(Mutex::new(false));
2688// let retries_failed_clone = retries_failed.clone();
2689
2690// let _subscription = thread.update(cx, |_, cx| {
2691// cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
2692// if let ThreadEvent::RetriesFailed { .. } = event {
2693// *retries_failed_clone.lock() = true;
2694// }
2695// })
2696// });
2697
2698// // Start initial completion
2699// thread.update(cx, |thread, cx| {
2700// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2701// });
2702// cx.run_until_parked();
2703
2704// // Advance through all retries
2705// for i in 0..MAX_RETRY_ATTEMPTS {
2706// let delay = if i == 0 {
2707// BASE_RETRY_DELAY_SECS
2708// } else {
2709// BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
2710// };
2711// cx.executor().advance_clock(Duration::from_secs(delay));
2712// cx.run_until_parked();
2713// }
2714
2715// // After the 3rd retry is scheduled, we need to wait for it to execute and fail
2716// // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
2717// let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
2718// cx.executor()
2719// .advance_clock(Duration::from_secs(final_delay));
2720// cx.run_until_parked();
2721
2722// let retry_count = thread.update(cx, |thread, _| {
2723// thread
2724// .messages
2725// .iter()
2726// .filter(|m| {
2727// m.ui_only
2728// && m.segments.iter().any(|s| {
2729// if let MessageSegment::Text(text) = s {
2730// text.contains("Retrying") && text.contains("seconds")
2731// } else {
2732// false
2733// }
2734// })
2735// })
2736// .count()
2737// });
2738
2739// // After max retries, should emit RetriesFailed event
2740// assert_eq!(
2741// retry_count, MAX_RETRY_ATTEMPTS as usize,
2742// "Should have attempted max retries"
2743// );
2744// assert!(
2745// *retries_failed.lock(),
2746// "Should emit RetriesFailed event after max retries exceeded"
2747// );
2748
2749// // Retry state should be cleared
2750// thread.read_with(cx, |thread, _| {
2751// assert!(
2752// thread.retry_state.is_none(),
2753// "Retry state should be cleared after max retries"
2754// );
2755
2756// // Verify we have the expected number of retry messages
2757// let retry_messages = thread
2758// .messages
2759// .iter()
2760// .filter(|msg| msg.ui_only && msg.role == Role::System)
2761// .count();
2762// assert_eq!(
2763// retry_messages, MAX_RETRY_ATTEMPTS as usize,
2764// "Should have one retry message per attempt"
2765// );
2766// });
2767// }
2768
2769// #[gpui::test]
2770// async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
2771// init_test_settings(cx);
2772
2773// let project = create_test_project(cx, json!({})).await;
2774// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2775
2776// // We'll use a wrapper to switch behavior after first failure
2777// struct RetryTestModel {
2778// inner: Arc<FakeLanguageModel>,
2779// failed_once: Arc<Mutex<bool>>,
2780// }
2781
2782// impl LanguageModel for RetryTestModel {
2783// fn id(&self) -> LanguageModelId {
2784// self.inner.id()
2785// }
2786
2787// fn name(&self) -> LanguageModelName {
2788// self.inner.name()
2789// }
2790
2791// fn provider_id(&self) -> LanguageModelProviderId {
2792// self.inner.provider_id()
2793// }
2794
2795// fn provider_name(&self) -> LanguageModelProviderName {
2796// self.inner.provider_name()
2797// }
2798
2799// fn supports_tools(&self) -> bool {
2800// self.inner.supports_tools()
2801// }
2802
2803// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
2804// self.inner.supports_tool_choice(choice)
2805// }
2806
2807// fn supports_images(&self) -> bool {
2808// self.inner.supports_images()
2809// }
2810
2811// fn telemetry_id(&self) -> String {
2812// self.inner.telemetry_id()
2813// }
2814
2815// fn max_token_count(&self) -> u64 {
2816// self.inner.max_token_count()
2817// }
2818
2819// fn count_tokens(
2820// &self,
2821// request: LanguageModelRequest,
2822// cx: &App,
2823// ) -> BoxFuture<'static, Result<u64>> {
2824// self.inner.count_tokens(request, cx)
2825// }
2826
2827// fn stream_completion(
2828// &self,
2829// request: LanguageModelRequest,
2830// cx: &AsyncApp,
2831// ) -> BoxFuture<
2832// 'static,
2833// Result<
2834// BoxStream<
2835// 'static,
2836// Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
2837// >,
2838// LanguageModelCompletionError,
2839// >,
2840// > {
2841// if !*self.failed_once.lock() {
2842// *self.failed_once.lock() = true;
2843// // Return error on first attempt
2844// let stream = futures::stream::once(async move {
2845// Err(LanguageModelCompletionError::Overloaded)
2846// });
2847// async move { Ok(stream.boxed()) }.boxed()
2848// } else {
2849// // Succeed on retry
2850// self.inner.stream_completion(request, cx)
2851// }
2852// }
2853
2854// fn as_fake(&self) -> &FakeLanguageModel {
2855// &self.inner
2856// }
2857// }
2858
2859// let model = Arc::new(RetryTestModel {
2860// inner: Arc::new(FakeLanguageModel::default()),
2861// failed_once: Arc::new(Mutex::new(false)),
2862// });
2863
2864// // Insert a user message
2865// thread.update(cx, |thread, cx| {
2866// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
2867// });
2868
// // Track when a retry completes successfully
2871// let retry_completed = Arc::new(Mutex::new(false));
2872// let retry_completed_clone = retry_completed.clone();
2873
2874// let _subscription = thread.update(cx, |_, cx| {
2875// cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
2876// if let ThreadEvent::StreamedCompletion = event {
2877// *retry_completed_clone.lock() = true;
2878// }
2879// })
2880// });
2881
2882// // Start completion
2883// thread.update(cx, |thread, cx| {
2884// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
2885// });
2886// cx.run_until_parked();
2887
2888// // Get the retry message ID
2889// let retry_message_id = thread.read_with(cx, |thread, _| {
2890// thread
2891// .messages()
2892// .find(|msg| msg.role == Role::System && msg.ui_only)
2893// .map(|msg| msg.id)
2894// .expect("Should have a retry message")
2895// });
2896
2897// // Wait for retry
2898// cx.executor()
2899// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
2900// cx.run_until_parked();
2901
2902// // Stream some successful content
2903// let fake_model = model.as_fake();
2904// // After the retry, there should be a new pending completion
2905// let pending = fake_model.pending_completions();
2906// assert!(
2907// !pending.is_empty(),
2908// "Should have a pending completion after retry"
2909// );
2910// fake_model.stream_completion_response(&pending[0], "Success!");
2911// fake_model.end_completion_stream(&pending[0]);
2912// cx.run_until_parked();
2913
2914// // Check that the retry completed successfully
2915// assert!(
2916// *retry_completed.lock(),
2917// "Retry should have completed successfully"
2918// );
2919
2920// // Retry message should still exist but be marked as ui_only
2921// thread.read_with(cx, |thread, _| {
2922// let retry_msg = thread
2923// .message(retry_message_id)
2924// .expect("Retry message should still exist");
2925// assert!(retry_msg.ui_only, "Retry message should be ui_only");
2926// assert_eq!(
2927// retry_msg.role,
2928// Role::System,
2929// "Retry message should have System role"
2930// );
2931// });
2932// }
2933
2934// #[gpui::test]
2935// async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
2936// init_test_settings(cx);
2937
2938// let project = create_test_project(cx, json!({})).await;
2939// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
2940
2941// // Create a model that fails once then succeeds
2942// struct FailOnceModel {
2943// inner: Arc<FakeLanguageModel>,
2944// failed_once: Arc<Mutex<bool>>,
2945// }
2946
2947// impl LanguageModel for FailOnceModel {
2948// fn id(&self) -> LanguageModelId {
2949// self.inner.id()
2950// }
2951
2952// fn name(&self) -> LanguageModelName {
2953// self.inner.name()
2954// }
2955
2956// fn provider_id(&self) -> LanguageModelProviderId {
2957// self.inner.provider_id()
2958// }
2959
2960// fn provider_name(&self) -> LanguageModelProviderName {
2961// self.inner.provider_name()
2962// }
2963
2964// fn supports_tools(&self) -> bool {
2965// self.inner.supports_tools()
2966// }
2967
2968// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
2969// self.inner.supports_tool_choice(choice)
2970// }
2971
2972// fn supports_images(&self) -> bool {
2973// self.inner.supports_images()
2974// }
2975
2976// fn telemetry_id(&self) -> String {
2977// self.inner.telemetry_id()
2978// }
2979
2980// fn max_token_count(&self) -> u64 {
2981// self.inner.max_token_count()
2982// }
2983
2984// fn count_tokens(
2985// &self,
2986// request: LanguageModelRequest,
2987// cx: &App,
2988// ) -> BoxFuture<'static, Result<u64>> {
2989// self.inner.count_tokens(request, cx)
2990// }
2991
2992// fn stream_completion(
2993// &self,
2994// request: LanguageModelRequest,
2995// cx: &AsyncApp,
2996// ) -> BoxFuture<
2997// 'static,
2998// Result<
2999// BoxStream<
3000// 'static,
3001// Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
3002// >,
3003// LanguageModelCompletionError,
3004// >,
3005// > {
3006// if !*self.failed_once.lock() {
3007// *self.failed_once.lock() = true;
3008// // Return error on first attempt
3009// let stream = futures::stream::once(async move {
3010// Err(LanguageModelCompletionError::Overloaded)
3011// });
3012// async move { Ok(stream.boxed()) }.boxed()
3013// } else {
3014// // Succeed on retry
3015// self.inner.stream_completion(request, cx)
3016// }
3017// }
3018// }
3019
3020// let fail_once_model = Arc::new(FailOnceModel {
3021// inner: Arc::new(FakeLanguageModel::default()),
3022// failed_once: Arc::new(Mutex::new(false)),
3023// });
3024
3025// // Insert a user message
3026// thread.update(cx, |thread, cx| {
3027// thread.insert_user_message(
3028// "Test message",
3029// ContextLoadResult::default(),
3030// None,
3031// vec![],
3032// cx,
3033// );
3034// });
3035
3036// // Start completion with fail-once model
3037// thread.update(cx, |thread, cx| {
3038// thread.send_to_model(
3039// fail_once_model.clone(),
3040// CompletionIntent::UserPrompt,
3041// None,
3042// cx,
3043// );
3044// });
3045
3046// cx.run_until_parked();
3047
3048// // Verify retry state exists after first failure
3049// thread.read_with(cx, |thread, _| {
3050// assert!(
3051// thread.retry_state.is_some(),
3052// "Should have retry state after failure"
3053// );
3054// });
3055
3056// // Wait for retry delay
3057// cx.executor()
3058// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
3059// cx.run_until_parked();
3060
3061// // The retry should now use our FailOnceModel which should succeed
3062// // We need to help the FakeLanguageModel complete the stream
3063// let inner_fake = fail_once_model.inner.clone();
3064
3065// // Wait a bit for the retry to start
3066// cx.run_until_parked();
3067
3068// // Check for pending completions and complete them
3069// if let Some(pending) = inner_fake.pending_completions().first() {
3070// inner_fake.stream_completion_response(pending, "Success!");
3071// inner_fake.end_completion_stream(pending);
3072// }
3073// cx.run_until_parked();
3074
3075// thread.read_with(cx, |thread, _| {
3076// assert!(
3077// thread.retry_state.is_none(),
3078// "Retry state should be cleared after successful completion"
3079// );
3080
3081// let has_assistant_message = thread
3082// .messages
3083// .iter()
3084// .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
3085// assert!(
3086// has_assistant_message,
3087// "Should have an assistant message after successful retry"
3088// );
3089// });
3090// }
3091
3092// #[gpui::test]
3093// async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
3094// init_test_settings(cx);
3095
3096// let project = create_test_project(cx, json!({})).await;
3097// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
3098
3099// // Create a model that returns rate limit error with retry_after
3100// struct RateLimitModel {
3101// inner: Arc<FakeLanguageModel>,
3102// }
3103
3104// impl LanguageModel for RateLimitModel {
3105// fn id(&self) -> LanguageModelId {
3106// self.inner.id()
3107// }
3108
3109// fn name(&self) -> LanguageModelName {
3110// self.inner.name()
3111// }
3112
3113// fn provider_id(&self) -> LanguageModelProviderId {
3114// self.inner.provider_id()
3115// }
3116
3117// fn provider_name(&self) -> LanguageModelProviderName {
3118// self.inner.provider_name()
3119// }
3120
3121// fn supports_tools(&self) -> bool {
3122// self.inner.supports_tools()
3123// }
3124
3125// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
3126// self.inner.supports_tool_choice(choice)
3127// }
3128
3129// fn supports_images(&self) -> bool {
3130// self.inner.supports_images()
3131// }
3132
3133// fn telemetry_id(&self) -> String {
3134// self.inner.telemetry_id()
3135// }
3136
3137// fn max_token_count(&self) -> u64 {
3138// self.inner.max_token_count()
3139// }
3140
3141// fn count_tokens(
3142// &self,
3143// request: LanguageModelRequest,
3144// cx: &App,
3145// ) -> BoxFuture<'static, Result<u64>> {
3146// self.inner.count_tokens(request, cx)
3147// }
3148
3149// fn stream_completion(
3150// &self,
3151// _request: LanguageModelRequest,
3152// _cx: &AsyncApp,
3153// ) -> BoxFuture<
3154// 'static,
3155// Result<
3156// BoxStream<
3157// 'static,
3158// Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
3159// >,
3160// LanguageModelCompletionError,
3161// >,
3162// > {
3163// async move {
3164// let stream = futures::stream::once(async move {
3165// Err(LanguageModelCompletionError::RateLimitExceeded {
3166// retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
3167// })
3168// });
3169// Ok(stream.boxed())
3170// }
3171// .boxed()
3172// }
3173
3174// fn as_fake(&self) -> &FakeLanguageModel {
3175// &self.inner
3176// }
3177// }
3178
3179// let model = Arc::new(RateLimitModel {
3180// inner: Arc::new(FakeLanguageModel::default()),
3181// });
3182
3183// // Insert a user message
3184// thread.update(cx, |thread, cx| {
3185// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
3186// });
3187
3188// // Start completion
3189// thread.update(cx, |thread, cx| {
3190// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3191// });
3192
3193// cx.run_until_parked();
3194
3195// let retry_count = thread.update(cx, |thread, _| {
3196// thread
3197// .messages
3198// .iter()
3199// .filter(|m| {
3200// m.ui_only
3201// && m.segments.iter().any(|s| {
3202// if let MessageSegment::Text(text) = s {
3203// text.contains("rate limit exceeded")
3204// } else {
3205// false
3206// }
3207// })
3208// })
3209// .count()
3210// });
3211// assert_eq!(retry_count, 1, "Should have scheduled one retry");
3212
3213// thread.read_with(cx, |thread, _| {
3214// assert!(
3215// thread.retry_state.is_none(),
3216// "Rate limit errors should not set retry_state"
3217// );
3218// });
3219
3220// // Verify we have one retry message
3221// thread.read_with(cx, |thread, _| {
3222// let retry_messages = thread
3223// .messages
3224// .iter()
3225// .filter(|msg| {
3226// msg.ui_only
3227// && msg.segments.iter().any(|seg| {
3228// if let MessageSegment::Text(text) = seg {
3229// text.contains("rate limit exceeded")
3230// } else {
3231// false
3232// }
3233// })
3234// })
3235// .count();
3236// assert_eq!(
3237// retry_messages, 1,
3238// "Should have one rate limit retry message"
3239// );
3240// });
3241
3242// // Check that retry message doesn't include attempt count
3243// thread.read_with(cx, |thread, _| {
3244// let retry_message = thread
3245// .messages
3246// .iter()
3247// .find(|msg| msg.role == Role::System && msg.ui_only)
3248// .expect("Should have a retry message");
3249
3250// // Check that the message doesn't contain attempt count
3251// if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
3252// assert!(
3253// !text.contains("attempt"),
3254// "Rate limit retry message should not contain attempt count"
3255// );
3256// assert!(
3257// text.contains(&format!(
3258// "Retrying in {} seconds",
3259// TEST_RATE_LIMIT_RETRY_SECS
3260// )),
3261// "Rate limit retry message should contain retry delay"
3262// );
3263// }
3264// });
3265// }
3266
3267// #[gpui::test]
3268// async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
3269// init_test_settings(cx);
3270
3271// let project = create_test_project(cx, json!({})).await;
3272// let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
3273
3274// // Insert a regular user message
3275// thread.update(cx, |thread, cx| {
3276// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
3277// });
3278
3279// // Insert a UI-only message (like our retry notifications)
3280// thread.update(cx, |thread, cx| {
3281// let id = thread.next_message_id.post_inc();
3282// thread.messages.push(Message {
3283// id,
3284// role: Role::System,
3285// segments: vec![MessageSegment::Text(
3286// "This is a UI-only message that should not be sent to the model".to_string(),
3287// )],
3288// loaded_context: LoadedContext::default(),
3289// creases: Vec::new(),
3290// is_hidden: true,
3291// ui_only: true,
3292// });
3293// cx.emit(ThreadEvent::MessageAdded(id));
3294// });
3295
3296// // Insert another regular message
3297// thread.update(cx, |thread, cx| {
3298// thread.insert_user_message(
3299// "How are you?",
3300// ContextLoadResult::default(),
3301// None,
3302// vec![],
3303// cx,
3304// );
3305// });
3306
3307// // Generate the completion request
3308// let request = thread.update(cx, |thread, cx| {
3309// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3310// });
3311
3312// // Verify that the request only contains non-UI-only messages
3313// // Should have system prompt + 2 user messages, but not the UI-only message
3314// let user_messages: Vec<_> = request
3315// .messages
3316// .iter()
3317// .filter(|msg| msg.role == Role::User)
3318// .collect();
3319// assert_eq!(
3320// user_messages.len(),
3321// 2,
3322// "Should have exactly 2 user messages"
3323// );
3324
3325// // Verify the UI-only content is not present anywhere in the request
3326// let request_text = request
3327// .messages
3328// .iter()
3329// .flat_map(|msg| &msg.content)
3330// .filter_map(|content| match content {
3331// MessageContent::Text(text) => Some(text.as_str()),
3332// _ => None,
3333// })
3334// .collect::<String>();
3335
3336// assert!(
3337// !request_text.contains("UI-only message"),
3338// "UI-only message content should not be in the request"
3339// );
3340
3341// // Verify the thread still has all 3 messages (including UI-only)
3342// thread.read_with(cx, |thread, _| {
3343// assert_eq!(
3344// thread.messages().count(),
3345// 3,
3346// "Thread should have 3 messages"
3347// );
3348// assert_eq!(
3349// thread.messages().filter(|m| m.ui_only).count(),
3350// 1,
3351// "Thread should have 1 UI-only message"
3352// );
3353// });
3354
3355// // Verify that UI-only messages are not serialized
3356// let serialized = thread
3357// .update(cx, |thread, cx| thread.serialize(cx))
3358// .await
3359// .unwrap();
3360// assert_eq!(
3361// serialized.messages.len(),
3362// 2,
3363// "Serialized thread should only have 2 messages (no UI-only)"
3364// );
3365// }
3366
3367// #[gpui::test]
3368// async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
3369// init_test_settings(cx);
3370
3371// let project = create_test_project(cx, json!({})).await;
3372// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
3373
3374// // Create model that returns overloaded error
3375// let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
3376
3377// // Insert a user message
3378// thread.update(cx, |thread, cx| {
3379// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
3380// });
3381
3382// // Start completion
3383// thread.update(cx, |thread, cx| {
3384// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3385// });
3386
3387// cx.run_until_parked();
3388
3389// // Verify retry was scheduled by checking for retry message
3390// let has_retry_message = thread.read_with(cx, |thread, _| {
3391// thread.messages.iter().any(|m| {
3392// m.ui_only
3393// && m.segments.iter().any(|s| {
3394// if let MessageSegment::Text(text) = s {
3395// text.contains("Retrying") && text.contains("seconds")
3396// } else {
3397// false
3398// }
3399// })
3400// })
3401// });
3402// assert!(has_retry_message, "Should have scheduled a retry");
3403
3404// // Cancel the completion before the retry happens
3405// thread.update(cx, |thread, cx| {
3406// thread.cancel_last_completion(None, cx);
3407// });
3408
3409// cx.run_until_parked();
3410
3411// // The retry should not have happened - no pending completions
3412// let fake_model = model.as_fake();
3413// assert_eq!(
3414// fake_model.pending_completions().len(),
3415// 0,
3416// "Should have no pending completions after cancellation"
3417// );
3418
3419// // Verify the retry was cancelled by checking retry state
3420// thread.read_with(cx, |thread, _| {
3421// if let Some(retry_state) = &thread.retry_state {
3422// panic!(
3423// "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
3424// retry_state.attempt, retry_state.max_attempts, retry_state.intent
3425// );
3426// }
3427// });
3428// }
3429
3430// fn test_summarize_error(
3431// model: &Arc<dyn LanguageModel>,
3432// thread: &Entity<Thread>,
3433// cx: &mut TestAppContext,
3434// ) {
3435// thread.update(cx, |thread, cx| {
3436// thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3437// thread.send_to_model(
3438// model.clone(),
3439// CompletionIntent::ThreadSummarization,
3440// None,
3441// cx,
3442// );
3443// });
3444
3445// let fake_model = model.as_fake();
3446// simulate_successful_response(&fake_model, cx);
3447
3448// thread.read_with(cx, |thread, _| {
3449// assert!(matches!(thread.summary(), ThreadSummary::Generating));
3450// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3451// });
3452
3453// // Simulate summary request ending
3454// cx.run_until_parked();
3455// fake_model.end_last_completion_stream();
3456// cx.run_until_parked();
3457
3458// // State is set to Error and default message
3459// thread.read_with(cx, |thread, _| {
3460// assert!(matches!(thread.summary(), ThreadSummary::Error));
3461// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3462// });
3463// }
3464
3465// fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
3466// cx.run_until_parked();
3467// fake_model.stream_last_completion_response("Assistant response");
3468// fake_model.end_last_completion_stream();
3469// cx.run_until_parked();
3470// }
3471
3472// fn init_test_settings(cx: &mut TestAppContext) {
3473// cx.update(|cx| {
3474// let settings_store = SettingsStore::test(cx);
3475// cx.set_global(settings_store);
3476// language::init(cx);
3477// Project::init_settings(cx);
3478// AgentSettings::register(cx);
3479// prompt_store::init(cx);
3480// thread_store::init(cx);
3481// workspace::init_settings(cx);
3482// language_model::init_settings(cx);
3483// ThemeSettings::register(cx);
3484// ToolRegistry::default_global(cx);
3485// });
3486// }
3487
3488// // Helper to create a test project with test files
3489// async fn create_test_project(
3490// cx: &mut TestAppContext,
3491// files: serde_json::Value,
3492// ) -> Entity<Project> {
3493// let fs = FakeFs::new(cx.executor());
3494// fs.insert_tree(path!("/test"), files).await;
3495// Project::test(fs, [path!("/test").as_ref()], cx).await
3496// }
3497
3498// async fn setup_test_environment(
3499// cx: &mut TestAppContext,
3500// project: Entity<Project>,
3501// ) -> (
3502// Entity<Workspace>,
3503// Entity<ThreadStore>,
3504// Entity<Thread>,
3505// Entity<ContextStore>,
3506// Arc<dyn LanguageModel>,
3507// ) {
3508// let (workspace, cx) =
3509// cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3510
3511// let thread_store = cx
3512// .update(|_, cx| {
3513// ThreadStore::load(
3514// project.clone(),
3515// cx.new(|_| ToolWorkingSet::default()),
3516// None,
3517// Arc::new(PromptBuilder::new(None).unwrap()),
3518// cx,
3519// )
3520// })
3521// .await
3522// .unwrap();
3523
3524// let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3525// let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3526
3527// let provider = Arc::new(FakeLanguageModelProvider);
3528// let model = provider.test_model();
3529// let model: Arc<dyn LanguageModel> = Arc::new(model);
3530
3531// cx.update(|_, cx| {
3532// LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3533// registry.set_default_model(
3534// Some(ConfiguredModel {
3535// provider: provider.clone(),
3536// model: model.clone(),
3537// }),
3538// cx,
3539// );
3540// registry.set_thread_summary_model(
3541// Some(ConfiguredModel {
3542// provider,
3543// model: model.clone(),
3544// }),
3545// cx,
3546// );
3547// })
3548// });
3549
3550// (workspace, thread_store, thread, context_store, model)
3551// }
3552
3553// async fn add_file_to_context(
3554// project: &Entity<Project>,
3555// context_store: &Entity<ContextStore>,
3556// path: &str,
3557// cx: &mut TestAppContext,
3558// ) -> Result<Entity<language::Buffer>> {
3559// let buffer_path = project
3560// .read_with(cx, |project, cx| project.find_project_path(path, cx))
3561// .unwrap();
3562
3563// let buffer = project
3564// .update(cx, |project, cx| {
3565// project.open_buffer(buffer_path.clone(), cx)
3566// })
3567// .await
3568// .unwrap();
3569
3570// context_store.update(cx, |context_store, cx| {
3571// context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3572// });
3573
3574// Ok(buffer)
3575// }
3576// }