1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use util::{ResultExt as _, TryFutureExt as _, post_inc};
39use uuid::Uuid;
40use zed_llm_client::CompletionMode;
41
42use crate::ThreadStore;
43use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
44use crate::thread_store::{
45 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
46 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
47};
48use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
49
/// Unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a fresh, random prompt ID from a newly generated UUID v4.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
90
/// Monotonically increasing identifier for a [`Message`] within a [`Thread`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
99
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text.
    pub range: Range<usize>,
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments that make up the message body.
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded and attached when the message was sent.
    pub loaded_context: LoadedContext,
    /// Creases used to re-render context references in editors.
    pub creases: Vec<MessageCrease>,
}
118
119impl Message {
120 /// Returns whether the message contains any meaningful text that should be displayed
121 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
122 pub fn should_display_content(&self) -> bool {
123 self.segments.iter().all(|segment| segment.should_display())
124 }
125
126 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
127 if let Some(MessageSegment::Thinking {
128 text: segment,
129 signature: current_signature,
130 }) = self.segments.last_mut()
131 {
132 if let Some(signature) = signature {
133 *current_signature = Some(signature);
134 }
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Thinking {
138 text: text.to_string(),
139 signature,
140 });
141 }
142 }
143
144 pub fn push_text(&mut self, text: &str) {
145 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
146 segment.push_str(text);
147 } else {
148 self.segments.push(MessageSegment::Text(text.to_string()));
149 }
150 }
151
152 pub fn to_string(&self) -> String {
153 let mut result = String::new();
154
155 if !self.loaded_context.text.is_empty() {
156 result.push_str(&self.loaded_context.text);
157 }
158
159 for segment in &self.segments {
160 match segment {
161 MessageSegment::Text(text) => result.push_str(text),
162 MessageSegment::Thinking { text, .. } => {
163 result.push_str("<think>\n");
164 result.push_str(text);
165 result.push_str("\n</think>");
166 }
167 MessageSegment::RedactedThinking(_) => {}
168 }
169 }
170
171 result
172 }
173}
174
/// One contiguous piece of a message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary visible text.
    Text(String),
    /// Model "thinking" output, with an optional provider-supplied signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque, provider-encoded thinking data that is not displayed.
    RedactedThinking(Vec<u8>),
}
184
185impl MessageSegment {
186 pub fn should_display(&self) -> bool {
187 match self {
188 Self::Text(text) => text.is_empty(),
189 Self::Thinking { text, .. } => text.is_empty(),
190 Self::RedactedThinking(_) => false,
191 }
192 }
193}
194
/// Point-in-time capture of the project's worktrees.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when no git state could be captured for this worktree.
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff text, when one could be produced.
    pub diff: Option<String>,
}

/// Associates a git checkpoint with the message that created it.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
227
/// Tracks the outcome of the most recent checkpoint restoration attempt.
pub enum LastRestoreCheckpoint {
    /// Restoration is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// Restoration failed; `error` holds the failure description.
    Error {
        message_id: MessageId,
        error: String,
    },
}
237
238impl LastRestoreCheckpoint {
239 pub fn message_id(&self) -> MessageId {
240 match self {
241 LastRestoreCheckpoint::Pending { message_id } => *message_id,
242 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
243 }
244 }
245}
246
/// Lifecycle of the on-demand detailed thread summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// A summary is being generated; `message_id` marks the associated message.
    Generating {
        message_id: MessageId,
    },
    /// A summary was produced; `message_id` marks the associated message.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
259
260impl DetailedSummaryState {
261 fn text(&self) -> Option<SharedString> {
262 if let Self::Generated { text, .. } = self {
263 Some(text.clone())
264 } else {
265 None
266 }
267 }
268}
269
/// Running token total for a thread, together with the model's context limit.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens allowed; `0` means the limit is unknown (no model selected).
    pub max: usize,
}
275
276impl TotalTokenUsage {
277 pub fn ratio(&self) -> TokenUsageRatio {
278 #[cfg(debug_assertions)]
279 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
280 .unwrap_or("0.8".to_string())
281 .parse()
282 .unwrap();
283 #[cfg(not(debug_assertions))]
284 let warning_threshold: f32 = 0.8;
285
286 // When the maximum is unknown because there is no selected model,
287 // avoid showing the token limit warning.
288 if self.max == 0 {
289 TokenUsageRatio::Normal
290 } else if self.total >= self.max {
291 TokenUsageRatio::Exceeded
292 } else if self.total as f32 / self.max as f32 >= warning_threshold {
293 TokenUsageRatio::Warning
294 } else {
295 TokenUsageRatio::Normal
296 }
297 }
298
299 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
300 TotalTokenUsage {
301 total: self.total + tokens,
302 max: self.max,
303 }
304 }
305}
306
/// Coarse classification of token usage relative to the context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Comfortably under the warning threshold.
    #[default]
    Normal,
    /// At or above the warning threshold, but still under the limit.
    Warning,
    /// At or above the context-window limit.
    Exceeded,
}
314
315fn default_completion_mode(cx: &App) -> CompletionMode {
316 if cx.is_staff() {
317 CompletionMode::Max
318 } else {
319 CompletionMode::Normal
320 }
321}
322
/// Where a pending completion currently sits relative to the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is waiting in the queue at `position`.
    Queued { position: usize },
    /// The provider has started processing the completion.
    Started,
}
329
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last time the thread was modified.
    updated_at: DateTime<Utc>,
    /// `None` until a summary has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: CompletionMode,
    /// Messages in ascending-ID order (relied on by the binary search in `message`).
    messages: Vec<Message>,
    next_message_id: MessageId,
    /// Regenerated via `advance_prompt_id` for each user submission.
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message that created them.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// In-flight completion requests; non-empty while generating.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Status of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting confirmation that the project actually changed.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the last streamed chunk; used for staleness detection.
    last_received_chunk_at: Option<Instant>,
    /// Optional callback invoked with each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining agentic turns before `send_to_model` becomes a no-op.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
377
378impl Thread {
    /// Creates a new, empty thread with a fresh ID and the registry's default model.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project's initial state in the background so it can
            // be attached when the thread is serialized.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
431
    /// Reconstructs a thread from its serialized form.
    ///
    /// Restores messages, tool-use state, and token usage; the previously
    /// selected model is looked up in the registry, falling back to the
    /// default model when it is no longer available.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest restored message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        // Context handles are not persisted; only the rendered
                        // text survives deserialization.
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
539
    /// Installs a callback that observes every completion request and the
    /// events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// This thread's unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, if one exists.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared project context used to build the system prompt
    /// (may not be populated yet; see `to_completion_request`).
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
575
576 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
577 if self.configured_model.is_none() {
578 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
579 }
580 self.configured_model.clone()
581 }
582
    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the model this thread should use and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Placeholder summary used before one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The thread's summary, or [`Self::DEFAULT_SUMMARY`] when none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
597
598 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
599 let Some(current_summary) = &self.summary else {
600 // Don't allow setting summary until generated
601 return;
602 };
603
604 let mut new_summary = new_summary.into();
605
606 if new_summary.is_empty() {
607 new_summary = Self::DEFAULT_SUMMARY;
608 }
609
610 if current_summary != &new_summary {
611 self.summary = Some(new_summary);
612 cx.emit(ThreadEvent::SummaryChanged);
613 }
614 }
615
    /// The thread's active completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the thread's active completion mode.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
623
624 pub fn message(&self, id: MessageId) -> Option<&Message> {
625 let index = self
626 .messages
627 .binary_search_by(|message| message.id.cmp(&id))
628 .ok()?;
629
630 self.messages.get(index)
631 }
632
    /// All messages in the thread, in ascending-ID order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // NOTE(review): the `None`-when-idle contract relies on
        // `last_received_chunk_at` being cleared when generation ends — confirm.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived (resets the staleness timer).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
663
664 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
665 self.tool_use
666 .pending_tool_uses()
667 .into_iter()
668 .find(|tool_use| &tool_use.id == id)
669 }
670
    /// Pending tool uses that are waiting for confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for a message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
685
    /// Restores the project to `checkpoint`'s git state, truncating the thread
    /// back to the checkpoint's message on success.
    ///
    /// Progress and failure are surfaced via `last_restore_checkpoint` and
    /// [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so observers can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed state around so the error can be shown.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the restored message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
721
    /// Commits the pending checkpoint once generation has finished.
    ///
    /// The checkpoint is only recorded when the project actually changed since
    /// it was taken; if taking the comparison checkpoint fails, it is kept to
    /// be safe.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still generating; leave the checkpoint pending.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only record the checkpoint when something actually changed.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
760
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Status of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
771
    /// Removes the message with `message_id` and every message after it,
    /// along with any checkpoints recorded for the removed messages.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// The contexts attached to a given message (empty for unknown IDs).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
793
794 pub fn is_turn_end(&self, ix: usize) -> bool {
795 if self.messages.is_empty() {
796 return false;
797 }
798
799 if !self.is_generating() && ix == self.messages.len() - 1 {
800 return true;
801 }
802
803 let Some(message) = self.messages.get(ix) else {
804 return false;
805 };
806
807 if message.role != Role::Assistant {
808 return false;
809 }
810
811 self.messages
812 .get(ix + 1)
813 .and_then(|message| {
814 self.message(message.id)
815 .map(|next_message| next_message.role == Role::User)
816 })
817 .unwrap_or(false)
818 }
819
    /// Whether the tool-use limit has been reached for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses issued by the message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results associated with the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has one.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a tool use's result, if any.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool use's result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
856
857 /// Return tools that are both enabled and supported by the model
858 pub fn available_tools(
859 &self,
860 cx: &App,
861 model: Arc<dyn LanguageModel>,
862 ) -> Vec<LanguageModelRequestTool> {
863 if model.supports_tools() {
864 self.tools()
865 .read(cx)
866 .enabled_tools(cx)
867 .into_iter()
868 .filter_map(|tool| {
869 // Skip tools that cannot be supported
870 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
871 Some(LanguageModelRequestTool {
872 name: tool.name(),
873 description: tool.description(),
874 input_schema,
875 })
876 })
877 .collect()
878 } else {
879 Vec::default()
880 }
881 }
882
    /// Appends a user message with its loaded context and creases.
    ///
    /// Buffers referenced by the loaded context are registered with the action
    /// log, and a supplied git checkpoint becomes the pending checkpoint for
    /// this message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            // Finalized once generation completes; see `finalize_pending_checkpoint`.
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
918
    /// Appends an assistant message with the given segments and no context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message with a freshly allocated ID and emits
    /// [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
953
954 pub fn edit_message(
955 &mut self,
956 id: MessageId,
957 new_role: Role,
958 new_segments: Vec<MessageSegment>,
959 loaded_context: Option<LoadedContext>,
960 cx: &mut Context<Self>,
961 ) -> bool {
962 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
963 return false;
964 };
965 message.role = new_role;
966 message.segments = new_segments;
967 if let Some(context) = loaded_context {
968 message.loaded_context = context;
969 }
970 self.touch_updated_at();
971 cx.emit(ThreadEvent::MessageEdited(id));
972 true
973 }
974
975 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
976 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
977 return false;
978 };
979 self.messages.remove(index);
980 self.touch_updated_at();
981 cx.emit(ThreadEvent::MessageDeleted(id));
982 true
983 }
984
985 /// Returns the representation of this [`Thread`] in a textual form.
986 ///
987 /// This is the representation we use when attaching a thread as context to another thread.
988 pub fn text(&self) -> String {
989 let mut text = String::new();
990
991 for message in &self.messages {
992 text.push_str(match message.role {
993 language_model::Role::User => "User:",
994 language_model::Role::Assistant => "Assistant:",
995 language_model::Role::System => "System:",
996 });
997 text.push('\n');
998
999 for segment in &message.segments {
1000 match segment {
1001 MessageSegment::Text(content) => text.push_str(content),
1002 MessageSegment::Thinking { text: content, .. } => {
1003 text.push_str(&format!("<think>{}</think>", content))
1004 }
1005 MessageSegment::RedactedThinking(_) => {}
1006 }
1007 }
1008 text.push('\n');
1009 }
1010
1011 text
1012 }
1013
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot, then captures messages,
    /// tool uses/results, token usage, and the selected model.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the rendered context text is persisted; handles
                        // are dropped (see `deserialize`).
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
            })
        })
    }
1094
    /// Number of agentic turns left before `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets the number of agentic turns the thread may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1102
    /// Sends the thread's current state to `model`, consuming one remaining
    /// turn. Does nothing when no turns are left.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1119
1120 pub fn used_tools_since_last_user_message(&self) -> bool {
1121 for message in self.messages.iter().rev() {
1122 if self.tool_use.message_has_tool_results(message.id) {
1123 return true;
1124 } else if message.role == Role::User {
1125 return false;
1126 }
1127 }
1128
1129 false
1130 }
1131
    /// Assembles the full `LanguageModelRequest` for the next completion:
    /// system prompt, the entire message history (with tool uses and tool
    /// results interleaved), stale-file notices, the set of available tools,
    /// and the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // The system prompt mentions which tools are available, so resolve
        // them up front.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate and prepend the system prompt. Failure here is surfaced to
        // the user but does not abort building the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        // The system prompt is stable across turns, so mark it
                        // cacheable for providers that support prompt caching.
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Replay the conversation: each thread message becomes a request
        // message (context + segments + tool uses), optionally followed by a
        // message carrying its tool results.
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // Mark the final message cacheable so the provider can cache the whole
        // conversation prefix.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        // Tell the model about files that changed since it last read them.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode)
        } else {
            Some(CompletionMode::Normal)
        };

        request
    }
1249
1250 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1251 let mut request = LanguageModelRequest {
1252 thread_id: None,
1253 prompt_id: None,
1254 mode: None,
1255 messages: vec![],
1256 tools: Vec::new(),
1257 stop: Vec::new(),
1258 temperature: None,
1259 };
1260
1261 for message in &self.messages {
1262 let mut request_message = LanguageModelRequestMessage {
1263 role: message.role,
1264 content: Vec::new(),
1265 cache: false,
1266 };
1267
1268 for segment in &message.segments {
1269 match segment {
1270 MessageSegment::Text(text) => request_message
1271 .content
1272 .push(MessageContent::Text(text.clone())),
1273 MessageSegment::Thinking { .. } => {}
1274 MessageSegment::RedactedThinking(_) => {}
1275 }
1276 }
1277
1278 if request_message.content.is_empty() {
1279 continue;
1280 }
1281
1282 request.messages.push(request_message);
1283 }
1284
1285 request.messages.push(LanguageModelRequestMessage {
1286 role: Role::User,
1287 content: vec![MessageContent::Text(added_user_message)],
1288 cache: false,
1289 });
1290
1291 request
1292 }
1293
1294 fn attached_tracked_files_state(
1295 &self,
1296 messages: &mut Vec<LanguageModelRequestMessage>,
1297 cx: &App,
1298 ) {
1299 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1300
1301 let mut stale_message = String::new();
1302
1303 let action_log = self.action_log.read(cx);
1304
1305 for stale_file in action_log.stale_buffers(cx) {
1306 let Some(file) = stale_file.read(cx).file() else {
1307 continue;
1308 };
1309
1310 if stale_message.is_empty() {
1311 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1312 }
1313
1314 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1315 }
1316
1317 let mut content = Vec::with_capacity(2);
1318
1319 if !stale_message.is_empty() {
1320 content.push(stale_message.into());
1321 }
1322
1323 if !content.is_empty() {
1324 let context_message = LanguageModelRequestMessage {
1325 role: Role::User,
1326 content,
1327 cache: false,
1328 };
1329
1330 messages.push(context_message);
1331 }
1332 }
1333
    /// Streams a completion for `request` from `model`, applying events to the
    /// thread as they arrive (text/thinking chunks, tool uses, token usage,
    /// queue status) and handling terminal success/error states when the
    /// stream ends. The spawned task is owned by a `PendingCompletion` pushed
    /// onto `self.pending_completions`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is installed, capture the request and every
        // raw response event so the callback can observe the full exchange.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot cumulative usage before the request so the delta can be
            // reported to telemetry afterwards.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        if let Some(usage) = usage {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        }
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message this response is being written into;
                // created lazily on the first relevant event.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Malformed tool-input JSON is recoverable: record
                            // it as a failed tool use and keep streaming.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            // Any other error aborts the stream.
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates are cumulative per request, so
                                // only the delta since the previous update is
                                // added to the running total.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Keep a copy of the partial input only while
                                // it is still streaming in, so it can be
                                // forwarded via `StreamedToolUse` below.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::QueueUpdate(status) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    let queue_state = match status {
                                        language_model::CompletionRequestStatus::Queued {
                                            position,
                                        } => Some(QueueState::Queued { position }),
                                        language_model::CompletionRequestStatus::Started => {
                                            Some(QueueState::Started)
                                        }
                                        language_model::CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            None
                                        }
                                    };

                                    if let Some(queue_state) = queue_state {
                                        completion.queue_state = queue_state;
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other tasks (e.g. UI updates) can run between
                    // events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            // The model requested tool calls: run the pending
                            // tools; their completion re-enters send_to_model.
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        // Map known error types to specific UI errors; fall
                        // back to showing the full error chain.
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    // Hand the captured request/response pair to the installed
                    // callback, if any.
                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the token usage delta for this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1667
    /// Kicks off generation of a short (3-7 word) thread title using the
    /// configured thread-summary model, storing the result in `self.summary`
    /// and emitting `SummaryGenerated`. Silently does nothing when no
    /// authenticated summary model is available.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                // Accumulate only the first line of output; a title should
                // never span multiple lines.
                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // Keep any previous summary if the model produced nothing.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1722
    /// Starts generating a detailed Markdown summary of the thread unless one
    /// is already generated (or generating) for the current last message.
    /// Progress and the final text flow through the `detailed_summary_tx`
    /// watch channel; the thread is saved afterwards so the summary persists.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed; roll the state back so a later call can
                // retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1812
    /// Waits on the detailed-summary watch channel until generation settles,
    /// then returns the generated summary, or the thread's raw text when no
    /// summary exists. Returns `None` if the entity or channel is dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still in flight; keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1830
1831 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1832 self.detailed_summary_rx
1833 .borrow()
1834 .text()
1835 .unwrap_or_else(|| self.text().into())
1836 }
1837
    /// Reports whether a detailed-summary generation task is currently in
    /// flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1844
    /// Dispatches all idle pending tool uses: tools that need user
    /// confirmation (and aren't globally auto-allowed) are parked awaiting
    /// confirmation, the rest are run immediately. Returns the tool uses that
    /// were dispatched.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the request messages as context for their execution.
        let request = self.to_completion_request(model, cx);
        let messages = Arc::new(request.messages);
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            // Tool uses whose tool is unknown are silently skipped here.
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        window,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }
1891
    /// Records a tool use whose input JSON failed to parse: stores an error
    /// result for the tool use, emits `InvalidToolInput` for the UI, and runs
    /// the normal tool-finished bookkeeping.
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
        );
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            // Shouldn't happen: the tool use just produced output, so a
            // pending entry was expected.
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        // Not canceled: the tool use completed, just with an error result.
        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
1926
1927 pub fn run_tool(
1928 &mut self,
1929 tool_use_id: LanguageModelToolUseId,
1930 ui_text: impl Into<SharedString>,
1931 input: serde_json::Value,
1932 messages: &[LanguageModelRequestMessage],
1933 tool: Arc<dyn Tool>,
1934 window: Option<AnyWindowHandle>,
1935 cx: &mut Context<Thread>,
1936 ) {
1937 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1938 self.tool_use
1939 .run_pending_tool(tool_use_id, ui_text.into(), task);
1940 }
1941
    /// Runs `tool` (unless it is disabled) and returns the task that awaits
    /// its output, records the result against `tool_use_id`, and triggers
    /// `tool_finished`. Any tool-result card is stored up front so the UI can
    /// render it while the tool is still running.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A disabled tool still produces a (failed) result so the model hears
        // back about its tool call.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        // `canceled: false` — the tool ran to completion,
                        // successfully or not.
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
1990
    /// Bookkeeping after a single tool use finishes. Once *all* tool uses are
    /// done, the conversation automatically continues by sending the results
    /// back to the configured model (unless the tool run was canceled).
    /// Always emits `ToolFinished` for the UI.
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
                if !canceled {
                    // Resume the turn so the model can react to the tool
                    // results.
                    self.send_to_model(model.clone(), window, cx);
                }
                self.auto_capture_telemetry(cx);
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
2013
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Also cancels every pending tool use (marking each as canceled so the
    /// conversation does not auto-continue) and finalizes the pending
    /// checkpoint.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Dropping the PendingCompletion drops its task, aborting the stream.
        let mut canceled = self.pending_completions.pop().is_some();

        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            // `canceled: true` prevents tool_finished from re-sending to the
            // model.
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        self.finalize_pending_checkpoint(cx);

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);
        }

        canceled
    }
2043
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2051
    /// Returns the thread-level feedback, if the user has rated the thread.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2055
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2059
    /// Records `feedback` for `message_id` and reports it to telemetry along
    /// with a serialized copy of the thread and a project snapshot. A repeat
    /// of the same rating for the same message is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Gather everything that needs main-thread access before moving to the
        // background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure degrades to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2117
    /// Records thread-level `feedback`. When the thread contains an assistant
    /// message, the feedback is attributed to the latest one via
    /// `report_message_feedback`; otherwise a thread-level telemetry event is
    /// sent directly.
    // NOTE(review): the else-branch duplicates the telemetry logic of
    // `report_message_feedback` minus the message-specific fields; consider
    // extracting a shared helper if the payloads are meant to stay in sync.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // A serialization failure degrades to a null payload rather
                // than dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2163
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree/git snapshots are gathered concurrently; dirty buffer paths
    /// are collected on the main thread afterwards.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record which open buffers have unsaved changes; best-effort, so
            // a failed app update simply leaves the list empty.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2201
2202 fn worktree_snapshot(
2203 worktree: Entity<project::Worktree>,
2204 git_store: Entity<GitStore>,
2205 cx: &App,
2206 ) -> Task<WorktreeSnapshot> {
2207 cx.spawn(async move |cx| {
2208 // Get worktree path and snapshot
2209 let worktree_info = cx.update(|app_cx| {
2210 let worktree = worktree.read(app_cx);
2211 let path = worktree.abs_path().to_string_lossy().to_string();
2212 let snapshot = worktree.snapshot();
2213 (path, snapshot)
2214 });
2215
2216 let Ok((worktree_path, _snapshot)) = worktree_info else {
2217 return WorktreeSnapshot {
2218 worktree_path: String::new(),
2219 git_state: None,
2220 };
2221 };
2222
2223 let git_state = git_store
2224 .update(cx, |git_store, cx| {
2225 git_store
2226 .repositories()
2227 .values()
2228 .find(|repo| {
2229 repo.read(cx)
2230 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2231 .is_some()
2232 })
2233 .cloned()
2234 })
2235 .ok()
2236 .flatten()
2237 .map(|repo| {
2238 repo.update(cx, |repo, _| {
2239 let current_branch =
2240 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2241 repo.send_job(None, |state, _| async move {
2242 let RepositoryState::Local { backend, .. } = state else {
2243 return GitState {
2244 remote_url: None,
2245 head_sha: None,
2246 current_branch,
2247 diff: None,
2248 };
2249 };
2250
2251 let remote_url = backend.remote_url("origin");
2252 let head_sha = backend.head_sha().await;
2253 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2254
2255 GitState {
2256 remote_url,
2257 head_sha,
2258 current_branch,
2259 diff,
2260 }
2261 })
2262 })
2263 });
2264
2265 let git_state = match git_state {
2266 Some(git_state) => match git_state.ok() {
2267 Some(git_state) => git_state.await.ok(),
2268 None => None,
2269 },
2270 None => None,
2271 };
2272
2273 WorktreeSnapshot {
2274 worktree_path,
2275 git_state,
2276 }
2277 })
2278 }
2279
2280 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2281 let mut markdown = Vec::new();
2282
2283 if let Some(summary) = self.summary() {
2284 writeln!(markdown, "# {summary}\n")?;
2285 };
2286
2287 for message in self.messages() {
2288 writeln!(
2289 markdown,
2290 "## {role}\n",
2291 role = match message.role {
2292 Role::User => "User",
2293 Role::Assistant => "Assistant",
2294 Role::System => "System",
2295 }
2296 )?;
2297
2298 if !message.loaded_context.text.is_empty() {
2299 writeln!(markdown, "{}", message.loaded_context.text)?;
2300 }
2301
2302 if !message.loaded_context.images.is_empty() {
2303 writeln!(
2304 markdown,
2305 "\n{} images attached as context.\n",
2306 message.loaded_context.images.len()
2307 )?;
2308 }
2309
2310 for segment in &message.segments {
2311 match segment {
2312 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2313 MessageSegment::Thinking { text, .. } => {
2314 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2315 }
2316 MessageSegment::RedactedThinking(_) => {}
2317 }
2318 }
2319
2320 for tool_use in self.tool_uses_for_message(message.id, cx) {
2321 writeln!(
2322 markdown,
2323 "**Use Tool: {} ({})**",
2324 tool_use.name, tool_use.id
2325 )?;
2326 writeln!(markdown, "```json")?;
2327 writeln!(
2328 markdown,
2329 "{}",
2330 serde_json::to_string_pretty(&tool_use.input)?
2331 )?;
2332 writeln!(markdown, "```")?;
2333 }
2334
2335 for tool_result in self.tool_results_for_message(message.id) {
2336 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2337 if tool_result.is_error {
2338 write!(markdown, " (Error)")?;
2339 }
2340
2341 writeln!(markdown, "**\n")?;
2342 writeln!(markdown, "{}", tool_result.content)?;
2343 }
2344 }
2345
2346 Ok(String::from_utf8_lossy(&markdown).to_string())
2347 }
2348
2349 pub fn keep_edits_in_range(
2350 &mut self,
2351 buffer: Entity<language::Buffer>,
2352 buffer_range: Range<language::Anchor>,
2353 cx: &mut Context<Self>,
2354 ) {
2355 self.action_log.update(cx, |action_log, cx| {
2356 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2357 });
2358 }
2359
2360 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2361 self.action_log
2362 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2363 }
2364
2365 pub fn reject_edits_in_ranges(
2366 &mut self,
2367 buffer: Entity<language::Buffer>,
2368 buffer_ranges: Vec<Range<language::Anchor>>,
2369 cx: &mut Context<Self>,
2370 ) -> Task<Result<()>> {
2371 self.action_log.update(cx, |action_log, cx| {
2372 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2373 })
2374 }
2375
    /// The log of file edits made by the agent in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2379
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2383
    /// Serializes the thread and reports it as a telemetry event, gated on
    /// the `ThreadAutoCapture` feature flag and throttled to at most once
    /// every ten seconds. Best-effort: all failures are silently dropped.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
            return;
        }

        // Throttle: skip if the last capture was under ten seconds ago.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        // Gather everything that needs `cx` up front, since the remaining
        // work runs on the background executor.
        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client().clone();
        let serialize_task = self.serialize(cx);

        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        // Push the event out immediately rather than waiting
                        // for the next scheduled flush.
                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }
2427
    /// Total token usage accumulated over every request in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2431
2432 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2433 let Some(model) = self.configured_model.as_ref() else {
2434 return TotalTokenUsage::default();
2435 };
2436
2437 let max = model.model.max_token_count();
2438
2439 let index = self
2440 .messages
2441 .iter()
2442 .position(|msg| msg.id == message_id)
2443 .unwrap_or(0);
2444
2445 if index == 0 {
2446 return TotalTokenUsage { total: 0, max };
2447 }
2448
2449 let token_usage = &self
2450 .request_token_usage
2451 .get(index - 1)
2452 .cloned()
2453 .unwrap_or_default();
2454
2455 TotalTokenUsage {
2456 total: token_usage.total_tokens() as usize,
2457 max,
2458 }
2459 }
2460
2461 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2462 let model = self.configured_model.as_ref()?;
2463
2464 let max = model.model.max_token_count();
2465
2466 if let Some(exceeded_error) = &self.exceeded_window_error {
2467 if model.model.id() == exceeded_error.model_id {
2468 return Some(TotalTokenUsage {
2469 total: exceeded_error.token_count,
2470 max,
2471 });
2472 }
2473 }
2474
2475 let total = self
2476 .token_usage_at_last_message()
2477 .unwrap_or_default()
2478 .total_tokens() as usize;
2479
2480 Some(TotalTokenUsage { total, max })
2481 }
2482
2483 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2484 self.request_token_usage
2485 .get(self.messages.len().saturating_sub(1))
2486 .or_else(|| self.request_token_usage.last())
2487 .cloned()
2488 }
2489
    /// Overwrites the usage entry aligned with the last message, first
    /// resizing `request_token_usage` to match `messages`. Note that
    /// `resize` both grows (filling with the latest known usage) and
    /// truncates (when messages were removed).
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Fill any newly created slots with the latest known usage so they
        // are a reasonable approximation rather than zeroes.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2499
2500 pub fn deny_tool_use(
2501 &mut self,
2502 tool_use_id: LanguageModelToolUseId,
2503 tool_name: Arc<str>,
2504 window: Option<AnyWindowHandle>,
2505 cx: &mut Context<Self>,
2506 ) {
2507 let err = Err(anyhow::anyhow!(
2508 "Permission to run tool action denied by user"
2509 ));
2510
2511 self.tool_use.insert_tool_output(
2512 tool_use_id.clone(),
2513 tool_name,
2514 err,
2515 self.configured_model.as_ref(),
2516 );
2517 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2518 }
2519}
2520
/// User-visible errors surfaced by a thread, typically via
/// `ThreadEvent::ShowError`.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider rejected the request pending payment.
    #[error("Payment required")]
    PaymentRequired,
    /// The account's configured monthly spend cap has been hit.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model-request quota has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a longer message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2535
/// Events emitted by a [`Thread`] as a completion streams in, messages
/// change, and tools run. Subscribers (e.g. the agent panel UI) react to
/// these to update their state.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// Updated request-usage accounting arrived from the server.
    UsageUpdated(RequestUsage),
    /// A completion event was received from the model.
    StreamedCompletion,
    /// A chunk of text arrived from the model.
    ReceivedTextChunk,
    /// A new completion request has started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a (possibly partial) tool use.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread for the first time.
    SummaryGenerated,
    /// The thread's summary text changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A git checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool is waiting for user confirmation before running.
    ToolConfirmationNeeded,
    /// Any in-progress editing should be canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2575
// Lets `Thread` broadcast `ThreadEvent`s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2577
/// State for a completion request that is currently in flight.
struct PendingCompletion {
    /// Monotonic id distinguishing this completion from later ones.
    id: usize,
    queue_state: QueueState,
    /// Holds the streaming task alive; dropping this struct drops (and thus
    /// cancels) the task.
    _task: Task<()>,
}
2583
2584#[cfg(test)]
2585mod tests {
2586 use super::*;
2587 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2588 use assistant_settings::AssistantSettings;
2589 use assistant_tool::ToolRegistry;
2590 use context_server::ContextServerSettings;
2591 use editor::EditorSettings;
2592 use gpui::TestAppContext;
2593 use language_model::fake_provider::FakeLanguageModel;
2594 use project::{FakeFs, Project};
2595 use prompt_store::PromptBuilder;
2596 use serde_json::json;
2597 use settings::{Settings, SettingsStore};
2598 use std::sync::Arc;
2599 use theme::ThemeSettings;
2600 use util::path;
2601 use workspace::Workspace;
2602
    // Verifies that a user message with one attached file renders the
    // expected <context> block, both on the stored message and in the
    // completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request: index 0 is the system prompt, index 1 the
        // user message with the context prepended to its text.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2678
    // Verifies that each message only carries context not already attached to
    // an earlier message, and that `new_context_for_thread` correctly reports
    // "new" contexts relative to an arbitrary earlier message.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages (plus the system prompt)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Relative to message 2, contexts 2-4 count as "new" (only context 1
        // was attached strictly before it).
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2834
    // Verifies that messages inserted with an empty `ContextLoadResult` carry
    // no context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request (index 0 is the system prompt)
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2912
    // Verifies that once a buffer attached as context is modified, the next
    // completion request ends with a "stale buffer" notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert a new statement at offset 1, making the buffer dirty
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2999
    // Registers every settings type and global the thread machinery touches.
    // NOTE(review): registration order may matter for inter-setting
    // dependencies — keep additions at the end unless proven otherwise.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3017
3018 // Helper to create a test project with test files
3019 async fn create_test_project(
3020 cx: &mut TestAppContext,
3021 files: serde_json::Value,
3022 ) -> Entity<Project> {
3023 let fs = FakeFs::new(cx.executor());
3024 fs.insert_tree(path!("/test"), files).await;
3025 Project::test(fs, [path!("/test").as_ref()], cx).await
3026 }
3027
    // Stands up the full harness a thread test needs: a test workspace over
    // `project`, a loaded `ThreadStore`, a fresh thread, an empty
    // `ContextStore`, and a fake language model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        // FakeLanguageModel never hits the network; tests drive it manually.
        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
3062
3063 async fn add_file_to_context(
3064 project: &Entity<Project>,
3065 context_store: &Entity<ContextStore>,
3066 path: &str,
3067 cx: &mut TestAppContext,
3068 ) -> Result<Entity<language::Buffer>> {
3069 let buffer_path = project
3070 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3071 .unwrap();
3072
3073 let buffer = project
3074 .update(cx, |project, cx| {
3075 project.open_buffer(buffer_path.clone(), cx)
3076 })
3077 .await
3078 .unwrap();
3079
3080 context_store.update(cx, |context_store, cx| {
3081 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3082 });
3083
3084 Ok(buffer)
3085 }
3086}