1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
18};
19use language_model::{
20 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
21 LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
22 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
23 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
24 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
25 TokenUsage,
26};
27use project::Project;
28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
29use prompt_store::PromptBuilder;
30use proto::Plan;
31use schemars::JsonSchema;
32use serde::{Deserialize, Serialize};
33use settings::Settings;
34use thiserror::Error;
35use util::{ResultExt as _, TryFutureExt as _, post_inc};
36use uuid::Uuid;
37
38use crate::context::{AssistantContext, ContextId, format_context_as_string};
39use crate::thread_store::{
40 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
41 SerializedToolUse, SharedProjectContext,
42};
43use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
44
/// The unique identifier of a [`Thread`].
///
/// Wraps an `Arc<str>` (a UUID v4 string) so clones are cheap.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
49
50impl ThreadId {
51 pub fn new() -> Self {
52 Self(Uuid::new_v4().to_string().into())
53 }
54}
55
56impl std::fmt::Display for ThreadId {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 write!(f, "{}", self.0)
59 }
60}
61
62impl From<&str> for ThreadId {
63 fn from(value: &str) -> Self {
64 Self(value.into())
65 }
66}
67
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
///
/// Wraps an `Arc<str>` (a UUID v4 string) so clones are cheap.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
73
74impl PromptId {
75 pub fn new() -> Self {
76 Self(Uuid::new_v4().to_string().into())
77 }
78}
79
80impl std::fmt::Display for PromptId {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 write!(f, "{}", self.0)
83 }
84}
85
/// The ID of a [`Message`] within a [`Thread`].
///
/// IDs are allocated sequentially via [`MessageId::post_inc`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// The ordered content pieces (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Formatted context (files, selections, etc.) attached to this message.
    pub context: String,
    /// Images attached by the user as context.
    pub images: Vec<LanguageModelImage>,
}
104
impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed
    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }

    /// Appends `text` as thinking content, extending a trailing `Thinking`
    /// segment when one exists (replacing its signature if a new one is
    /// provided), otherwise starting a new `Thinking` segment.
    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        if let Some(MessageSegment::Thinking {
            text: segment,
            signature: current_signature,
        }) = self.segments.last_mut()
        {
            // Only overwrite the stored signature when a new one was supplied.
            if let Some(signature) = signature {
                *current_signature = Some(signature);
            }
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Thinking {
                text: text.to_string(),
                signature,
            });
        }
    }

    /// Appends plain text, extending a trailing `Text` segment when one
    /// exists, otherwise starting a new `Text` segment.
    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

    /// Renders the message as plain text: attached context first, then each
    /// segment in order. Thinking segments are wrapped in `<think>` tags;
    /// redacted thinking is omitted.
    pub fn to_string(&self) -> String {
        let mut result = String::new();

        if !self.context.is_empty() {
            result.push_str(&self.context);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    result.push_str("<think>\n");
                    result.push_str(text);
                    result.push_str("\n</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        result
    }
}
160
/// One contiguous piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary visible text.
    Text(String),
    /// Model "thinking" output, with an optional signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted thinking data.
    RedactedThinking(Vec<u8>),
}
170
171impl MessageSegment {
172 pub fn should_display(&self) -> bool {
173 match self {
174 Self::Text(text) => text.is_empty(),
175 Self::Thinking { text, .. } => text.is_empty(),
176 Self::RedactedThinking(_) => false,
177 }
178 }
179}
180
/// A snapshot of the project's worktrees at a point in time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// The state of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git metadata, if the worktree is inside a repository.
    pub git_state: Option<GitState>,
}

/// Git repository metadata captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// A git checkpoint tied to a particular message, used to restore the
/// project to the state it was in when that message was sent.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
213
/// The state of the most recent checkpoint-restore operation.
pub enum LastRestoreCheckpoint {
    /// A restore is in flight for the given message.
    Pending {
        message_id: MessageId,
    },
    /// The last restore failed with the given error.
    Error {
        message_id: MessageId,
        error: String,
    },
}

impl LastRestoreCheckpoint {
    /// The ID of the message whose checkpoint was being restored.
    pub fn message_id(&self) -> MessageId {
        match self {
            LastRestoreCheckpoint::Pending { message_id } => *message_id,
            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
        }
    }
}

/// Progress of generating a detailed (long-form) summary of the thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// Generation is in progress.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary has been generated.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}

/// Running token totals used to report context-window usage.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens used so far.
    pub total: usize,
    /// The maximum token budget (context window size).
    pub max: usize,
}
251
252impl TotalTokenUsage {
253 pub fn ratio(&self) -> TokenUsageRatio {
254 #[cfg(debug_assertions)]
255 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
256 .unwrap_or("0.8".to_string())
257 .parse()
258 .unwrap();
259 #[cfg(not(debug_assertions))]
260 let warning_threshold: f32 = 0.8;
261
262 if self.total >= self.max {
263 TokenUsageRatio::Exceeded
264 } else if self.total as f32 / self.max as f32 >= warning_threshold {
265 TokenUsageRatio::Warning
266 } else {
267 TokenUsageRatio::Normal
268 }
269 }
270
271 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
272 TotalTokenUsage {
273 total: self.total + tokens,
274 max: self.max,
275 }
276 }
277}
278
/// How close a thread is to exhausting its token budget.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    /// Usage crossed the warning threshold (80% by default).
    Warning,
    /// Usage reached or exceeded the maximum.
    Exceeded,
}

/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// `None` until a summary has been generated; see [`Thread::set_summary`].
    summary: Option<SharedString>,
    /// In-flight summary-generation task, if any.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    /// Messages in ID order (relied on by [`Thread::message`]'s binary search).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    /// All context ever attached to this thread, keyed by ID.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Which context IDs were attached alongside each user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured at the last user message, finalized (or dropped)
    /// by [`Thread::finalize_pending_checkpoint`] once generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    /// Set when a request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Optional callback receiving each request and its completion events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Turns [`Thread::send_to_model`] may still start; decremented per call.
    remaining_turns: u32,
}

/// Details recorded when a completion exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
330
331impl Thread {
    /// Creates a new, empty thread for the given project.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project's state as of thread creation; the shared
            // task lets all later consumers await the same snapshot.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
377
    /// Reconstructs a thread from its serialized form.
    ///
    /// Attached context and checkpoints are not persisted, so `context`,
    /// `context_by_message`, and `checkpoints_by_message` start out empty;
    /// message IDs continue after the last deserialized message's ID.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the last persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                    // Images are not persisted; restored messages start without them.
                    images: Vec::new(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
451
    /// Stores a callback to receive completion requests and their events.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Returns whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps `updated_at` to the current time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if one hasn't been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Summary used before a real one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The generated summary, falling back to [`Self::DEFAULT_SUMMARY`].
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    /// Replaces the summary, substituting the default for an empty string and
    /// emitting [`ThreadEvent::SummaryChanged`] when the text changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// The latest detailed summary, or the thread's full text if none has
    /// been generated yet.
    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.latest_detailed_summary()
            .unwrap_or_else(|| self.text().into())
    }

    fn latest_detailed_summary(&self) -> Option<SharedString> {
        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
            Some(text.clone())
        } else {
            None
        }
    }
524
    /// Looks up a message by ID.
    ///
    /// Uses binary search, relying on `messages` being sorted by ID (IDs are
    /// allocated in increasing order).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Returns whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint associated with the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
567
    /// Restores the project to the given checkpoint's git state.
    ///
    /// Marks the restore as pending, asks the git store to restore the
    /// checkpoint, and on success truncates the thread back to the
    /// checkpoint's message; on failure the error is kept for display via
    /// [`Thread::last_restore_checkpoint`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Roll the conversation back to the restored point.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
603
    /// Resolves the checkpoint taken at the last user message once the
    /// thread has finished generating.
    ///
    /// Takes a fresh checkpoint and compares it against the pending one: if
    /// the project changed (or the comparison fails), the pending checkpoint
    /// is kept and attached to its message; if nothing changed it is dropped.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still generating; leave the checkpoint pending for now.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If we can't take a final checkpoint, keep the pending one
            // rather than silently losing it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
642
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes the message with the given ID and every message after it,
    /// along with their per-message context and checkpoints.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.context_by_message.remove(&deleted_message.id);
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
668
    /// Iterates over the context attached to the given message, in the order
    /// it was attached.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
        self.context_by_message
            .get(&id)
            .into_iter()
            .flat_map(|context| {
                context
                    .iter()
                    .filter_map(|context_id| self.context.get(&context_id))
            })
    }
679
680 pub fn is_turn_end(&self, ix: usize) -> bool {
681 if self.messages.is_empty() {
682 return false;
683 }
684
685 if !self.is_generating() && ix == self.messages.len() - 1 {
686 return true;
687 }
688
689 let Some(message) = self.messages.get(ix) else {
690 return false;
691 };
692
693 if message.role != Role::Assistant {
694 return false;
695 }
696
697 self.messages
698 .get(ix + 1)
699 .and_then(|message| {
700 self.message(message.id)
701 .map(|next_message| next_message.role == Role::User)
702 })
703 .unwrap_or(false)
704 }
705
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses issued by the given assistant message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced for the given assistant message's tool uses.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual output of a completed tool use, if available.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool result, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
738
739 /// Filter out contexts that have already been included in previous messages
740 pub fn filter_new_context<'a>(
741 &self,
742 context: impl Iterator<Item = &'a AssistantContext>,
743 ) -> impl Iterator<Item = &'a AssistantContext> {
744 context.filter(|ctx| self.is_context_new(ctx))
745 }
746
747 fn is_context_new(&self, context: &AssistantContext) -> bool {
748 !self.context.contains_key(&context.id())
749 }
750
    /// Inserts a user message with the given attached context.
    ///
    /// Context already attached to this thread is filtered out; the rest is
    /// formatted into the message, its images are attached, and any buffers
    /// it references are tracked in the action log. If a git checkpoint is
    /// provided it becomes the pending checkpoint for this message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                // Attach any images that have already finished loading.
                message.images = new_context
                    .iter()
                    .filter_map(|context| {
                        if let AssistantContext::Image(image_context) = context {
                            image_context.image_task.clone().now_or_never().flatten()
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.track_buffer(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
                        }
                        AssistantContext::Selection(selection_context) => {
                            log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
                        }
                        // These context kinds reference no buffers.
                        AssistantContext::FetchedUrl(_)
                        | AssistantContext::Thread(_)
                        | AssistantContext::Rules(_)
                        | AssistantContext::Image(_) => {}
                    }
                }
            });
        }

        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Appends a message with a fresh ID and emits
    /// [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            context: String::new(),
            images: Vec::new(),
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }

    /// Replaces a message's role and segments; returns whether the message
    /// existed.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes a message and its context mapping; returns whether it existed.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.context_by_message.remove(&id);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
883
884 /// Returns the representation of this [`Thread`] in a textual form.
885 ///
886 /// This is the representation we use when attaching a thread as context to another thread.
887 pub fn text(&self) -> String {
888 let mut text = String::new();
889
890 for message in &self.messages {
891 text.push_str(match message.role {
892 language_model::Role::User => "User:",
893 language_model::Role::Assistant => "Assistant:",
894 language_model::Role::System => "System:",
895 });
896 text.push('\n');
897
898 for segment in &message.segments {
899 match segment {
900 MessageSegment::Text(content) => text.push_str(content),
901 MessageSegment::Thinking { text: content, .. } => {
902 text.push_str(&format!("<think>{}</think>", content))
903 }
904 MessageSegment::RedactedThinking(_) => {}
905 }
906 }
907 text.push('\n');
908 }
909
910 text
911 }
912
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the creation-time snapshot before reading the thread.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }

    /// How many turns [`Thread::send_to_model`] may still start.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
984
    /// Sends the current thread state to `model`, consuming one remaining
    /// turn.
    ///
    /// Does nothing when no turns remain. When the model supports tool use,
    /// the enabled tools whose input schemas are representable in the model's
    /// format are attached to the request.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let mut request = self.to_completion_request(cx);
        if model.supports_tools() {
            request.tools = {
                let mut tools = Vec::new();
                tools.extend(
                    self.tools()
                        .read(cx)
                        .enabled_tools(cx)
                        .into_iter()
                        .filter_map(|tool| {
                            // Skip tools that cannot be supported
                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                            Some(LanguageModelRequestTool {
                                name: tool.name(),
                                description: tool.description(),
                                input_schema,
                            })
                        }),
                );

                tools
            };
        }

        self.stream_completion(request, model, window, cx);
    }
1023
1024 pub fn used_tools_since_last_user_message(&self) -> bool {
1025 for message in self.messages.iter().rev() {
1026 if self.tool_use.message_has_tool_results(message.id) {
1027 return true;
1028 } else if message.role == Role::User {
1029 return false;
1030 }
1031 }
1032
1033 false
1034 }
1035
    /// Builds the completion request representing this thread's state.
    ///
    /// Produces, in order: a cached system prompt (when project context is
    /// ready), each message's attached context, images, segments, and tool
    /// uses/results, then stale-buffer notices. The final message is marked
    /// cacheable for prompt caching.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context precedes the message's own content.
            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            if !message.images.is_empty() {
                // Some providers only support image parts after an initial text part
                if request_message.content.is_empty() {
                    request_message
                        .content
                        .push(MessageContent::Text("Images attached by user:".to_string()));
                }

                for image in &message.images {
                    request_message
                        .content
                        .push(MessageContent::Image(image.clone()))
                }
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow the assistant message that requested them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1148
1149 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1150 let mut request = LanguageModelRequest {
1151 thread_id: None,
1152 prompt_id: None,
1153 messages: vec![],
1154 tools: Vec::new(),
1155 stop: Vec::new(),
1156 temperature: None,
1157 };
1158
1159 for message in &self.messages {
1160 let mut request_message = LanguageModelRequestMessage {
1161 role: message.role,
1162 content: Vec::new(),
1163 cache: false,
1164 };
1165
1166 for segment in &message.segments {
1167 match segment {
1168 MessageSegment::Text(text) => request_message
1169 .content
1170 .push(MessageContent::Text(text.clone())),
1171 MessageSegment::Thinking { .. } => {}
1172 MessageSegment::RedactedThinking(_) => {}
1173 }
1174 }
1175
1176 if request_message.content.is_empty() {
1177 continue;
1178 }
1179
1180 request.messages.push(request_message);
1181 }
1182
1183 request.messages.push(LanguageModelRequestMessage {
1184 role: Role::User,
1185 content: vec![MessageContent::Text(added_user_message)],
1186 cache: false,
1187 });
1188
1189 request
1190 }
1191
1192 fn attached_tracked_files_state(
1193 &self,
1194 messages: &mut Vec<LanguageModelRequestMessage>,
1195 cx: &App,
1196 ) {
1197 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1198
1199 let mut stale_message = String::new();
1200
1201 let action_log = self.action_log.read(cx);
1202
1203 for stale_file in action_log.stale_buffers(cx) {
1204 let Some(file) = stale_file.read(cx).file() else {
1205 continue;
1206 };
1207
1208 if stale_message.is_empty() {
1209 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1210 }
1211
1212 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1213 }
1214
1215 let mut content = Vec::with_capacity(2);
1216
1217 if !stale_message.is_empty() {
1218 content.push(stale_message.into());
1219 }
1220
1221 if !content.is_empty() {
1222 let context_message = LanguageModelRequestMessage {
1223 role: Role::User,
1224 content,
1225 cache: false,
1226 };
1227
1228 messages.push(context_message);
1229 }
1230 }
1231
1232 pub fn stream_completion(
1233 &mut self,
1234 request: LanguageModelRequest,
1235 model: Arc<dyn LanguageModel>,
1236 window: Option<AnyWindowHandle>,
1237 cx: &mut Context<Self>,
1238 ) {
1239 let pending_completion_id = post_inc(&mut self.completion_count);
1240 let mut request_callback_parameters = if self.request_callback.is_some() {
1241 Some((request.clone(), Vec::new()))
1242 } else {
1243 None
1244 };
1245 let prompt_id = self.last_prompt_id.clone();
1246 let tool_use_metadata = ToolUseMetadata {
1247 model: model.clone(),
1248 thread_id: self.id.clone(),
1249 prompt_id: prompt_id.clone(),
1250 };
1251
1252 let task = cx.spawn(async move |thread, cx| {
1253 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1254 let initial_token_usage =
1255 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1256 let stream_completion = async {
1257 let (mut events, usage) = stream_completion_future.await?;
1258
1259 let mut stop_reason = StopReason::EndTurn;
1260 let mut current_token_usage = TokenUsage::default();
1261
1262 if let Some(usage) = usage {
1263 thread
1264 .update(cx, |_thread, cx| {
1265 cx.emit(ThreadEvent::UsageUpdated(usage));
1266 })
1267 .ok();
1268 }
1269
1270 while let Some(event) = events.next().await {
1271 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1272 response_events
1273 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1274 }
1275
1276 let event = event?;
1277
1278 thread.update(cx, |thread, cx| {
1279 match event {
1280 LanguageModelCompletionEvent::StartMessage { .. } => {
1281 thread.insert_message(
1282 Role::Assistant,
1283 vec![MessageSegment::Text(String::new())],
1284 cx,
1285 );
1286 }
1287 LanguageModelCompletionEvent::Stop(reason) => {
1288 stop_reason = reason;
1289 }
1290 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1291 thread.update_token_usage_at_last_message(token_usage);
1292 thread.cumulative_token_usage = thread.cumulative_token_usage
1293 + token_usage
1294 - current_token_usage;
1295 current_token_usage = token_usage;
1296 }
1297 LanguageModelCompletionEvent::Text(chunk) => {
1298 cx.emit(ThreadEvent::ReceivedTextChunk);
1299 if let Some(last_message) = thread.messages.last_mut() {
1300 if last_message.role == Role::Assistant
1301 && !thread.tool_use.has_tool_results(last_message.id)
1302 {
1303 last_message.push_text(&chunk);
1304 cx.emit(ThreadEvent::StreamedAssistantText(
1305 last_message.id,
1306 chunk,
1307 ));
1308 } else {
1309 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1310 // of a new Assistant response.
1311 //
1312 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1313 // will result in duplicating the text of the chunk in the rendered Markdown.
1314 thread.insert_message(
1315 Role::Assistant,
1316 vec![MessageSegment::Text(chunk.to_string())],
1317 cx,
1318 );
1319 };
1320 }
1321 }
1322 LanguageModelCompletionEvent::Thinking {
1323 text: chunk,
1324 signature,
1325 } => {
1326 if let Some(last_message) = thread.messages.last_mut() {
1327 if last_message.role == Role::Assistant
1328 && !thread.tool_use.has_tool_results(last_message.id)
1329 {
1330 last_message.push_thinking(&chunk, signature);
1331 cx.emit(ThreadEvent::StreamedAssistantThinking(
1332 last_message.id,
1333 chunk,
1334 ));
1335 } else {
1336 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1337 // of a new Assistant response.
1338 //
1339 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1340 // will result in duplicating the text of the chunk in the rendered Markdown.
1341 thread.insert_message(
1342 Role::Assistant,
1343 vec![MessageSegment::Thinking {
1344 text: chunk.to_string(),
1345 signature,
1346 }],
1347 cx,
1348 );
1349 };
1350 }
1351 }
1352 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1353 let last_assistant_message_id = thread
1354 .messages
1355 .iter_mut()
1356 .rfind(|message| message.role == Role::Assistant)
1357 .map(|message| message.id)
1358 .unwrap_or_else(|| {
1359 thread.insert_message(Role::Assistant, vec![], cx)
1360 });
1361
1362 let tool_use_id = tool_use.id.clone();
1363 let streamed_input = if tool_use.is_input_complete {
1364 None
1365 } else {
1366 Some((&tool_use.input).clone())
1367 };
1368
1369 let ui_text = thread.tool_use.request_tool_use(
1370 last_assistant_message_id,
1371 tool_use,
1372 tool_use_metadata.clone(),
1373 cx,
1374 );
1375
1376 if let Some(input) = streamed_input {
1377 cx.emit(ThreadEvent::StreamedToolUse {
1378 tool_use_id,
1379 ui_text,
1380 input,
1381 });
1382 }
1383 }
1384 }
1385
1386 thread.touch_updated_at();
1387 cx.emit(ThreadEvent::StreamedCompletion);
1388 cx.notify();
1389
1390 thread.auto_capture_telemetry(cx);
1391 })?;
1392
1393 smol::future::yield_now().await;
1394 }
1395
1396 thread.update(cx, |thread, cx| {
1397 thread
1398 .pending_completions
1399 .retain(|completion| completion.id != pending_completion_id);
1400
1401 // If there is a response without tool use, summarize the message. Otherwise,
1402 // allow two tool uses before summarizing.
1403 if thread.summary.is_none()
1404 && thread.messages.len() >= 2
1405 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1406 {
1407 thread.summarize(cx);
1408 }
1409 })?;
1410
1411 anyhow::Ok(stop_reason)
1412 };
1413
1414 let result = stream_completion.await;
1415
1416 thread
1417 .update(cx, |thread, cx| {
1418 thread.finalize_pending_checkpoint(cx);
1419 match result.as_ref() {
1420 Ok(stop_reason) => match stop_reason {
1421 StopReason::ToolUse => {
1422 let tool_uses = thread.use_pending_tools(window, cx);
1423 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1424 }
1425 StopReason::EndTurn => {}
1426 StopReason::MaxTokens => {}
1427 },
1428 Err(error) => {
1429 if error.is::<PaymentRequiredError>() {
1430 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1431 } else if error.is::<MaxMonthlySpendReachedError>() {
1432 cx.emit(ThreadEvent::ShowError(
1433 ThreadError::MaxMonthlySpendReached,
1434 ));
1435 } else if let Some(error) =
1436 error.downcast_ref::<ModelRequestLimitReachedError>()
1437 {
1438 cx.emit(ThreadEvent::ShowError(
1439 ThreadError::ModelRequestLimitReached { plan: error.plan },
1440 ));
1441 } else if let Some(known_error) =
1442 error.downcast_ref::<LanguageModelKnownError>()
1443 {
1444 match known_error {
1445 LanguageModelKnownError::ContextWindowLimitExceeded {
1446 tokens,
1447 } => {
1448 thread.exceeded_window_error = Some(ExceededWindowError {
1449 model_id: model.id(),
1450 token_count: *tokens,
1451 });
1452 cx.notify();
1453 }
1454 }
1455 } else {
1456 let error_message = error
1457 .chain()
1458 .map(|err| err.to_string())
1459 .collect::<Vec<_>>()
1460 .join("\n");
1461 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1462 header: "Error interacting with language model".into(),
1463 message: SharedString::from(error_message.clone()),
1464 }));
1465 }
1466
1467 thread.cancel_last_completion(window, cx);
1468 }
1469 }
1470 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1471
1472 if let Some((request_callback, (request, response_events))) = thread
1473 .request_callback
1474 .as_mut()
1475 .zip(request_callback_parameters.as_ref())
1476 {
1477 request_callback(request, response_events);
1478 }
1479
1480 thread.auto_capture_telemetry(cx);
1481
1482 if let Ok(initial_usage) = initial_token_usage {
1483 let usage = thread.cumulative_token_usage - initial_usage;
1484
1485 telemetry::event!(
1486 "Assistant Thread Completion",
1487 thread_id = thread.id().to_string(),
1488 prompt_id = prompt_id,
1489 model = model.telemetry_id(),
1490 model_provider = model.provider_id().to_string(),
1491 input_tokens = usage.input_tokens,
1492 output_tokens = usage.output_tokens,
1493 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1494 cache_read_input_tokens = usage.cache_read_input_tokens,
1495 );
1496 }
1497 })
1498 .ok();
1499 });
1500
1501 self.pending_completions.push(PendingCompletion {
1502 id: pending_completion_id,
1503 _task: task,
1504 });
1505 }
1506
    /// Kicks off generation of a short (3-7 word) thread title using the
    /// configured thread-summary model, storing the result in `self.summary`
    /// and emitting `ThreadEvent::SummaryGenerated` when done.
    ///
    /// Does nothing when no summary model is configured or its provider is not
    /// authenticated. Replaces any previously pending summary task.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            // Inner async block so `.log_err()` can swallow any failure.
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                // Accumulate only the first line of output; anything past a
                // newline means the model kept going, so stop reading.
                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // An empty result leaves the previous summary untouched.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1561
    /// Starts generating a detailed Markdown summary of the thread, unless one
    /// is already generated (or generating) for the current last message.
    ///
    /// Returns `None` when the summary is already up-to-date, the thread has
    /// no messages, no summary model is configured, or the provider is not
    /// authenticated; otherwise returns the task driving the generation.
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails outright, reset the state so a retry is possible.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        // Mark as generating for the current last message; the task above
        // overwrites this with the final state when it completes.
        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1628
    /// Reports whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1635
1636 pub fn use_pending_tools(
1637 &mut self,
1638 window: Option<AnyWindowHandle>,
1639 cx: &mut Context<Self>,
1640 ) -> Vec<PendingToolUse> {
1641 self.auto_capture_telemetry(cx);
1642 let request = self.to_completion_request(cx);
1643 let messages = Arc::new(request.messages);
1644 let pending_tool_uses = self
1645 .tool_use
1646 .pending_tool_uses()
1647 .into_iter()
1648 .filter(|tool_use| tool_use.status.is_idle())
1649 .cloned()
1650 .collect::<Vec<_>>();
1651
1652 for tool_use in pending_tool_uses.iter() {
1653 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1654 if tool.needs_confirmation(&tool_use.input, cx)
1655 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1656 {
1657 self.tool_use.confirm_tool_use(
1658 tool_use.id.clone(),
1659 tool_use.ui_text.clone(),
1660 tool_use.input.clone(),
1661 messages.clone(),
1662 tool,
1663 );
1664 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1665 } else {
1666 self.run_tool(
1667 tool_use.id.clone(),
1668 tool_use.ui_text.clone(),
1669 tool_use.input.clone(),
1670 &messages,
1671 tool,
1672 window,
1673 cx,
1674 );
1675 }
1676 }
1677 }
1678
1679 pending_tool_uses
1680 }
1681
1682 pub fn run_tool(
1683 &mut self,
1684 tool_use_id: LanguageModelToolUseId,
1685 ui_text: impl Into<SharedString>,
1686 input: serde_json::Value,
1687 messages: &[LanguageModelRequestMessage],
1688 tool: Arc<dyn Tool>,
1689 window: Option<AnyWindowHandle>,
1690 cx: &mut Context<Thread>,
1691 ) {
1692 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1693 self.tool_use
1694 .run_pending_tool(tool_use_id, ui_text.into(), task);
1695 }
1696
    /// Spawns the asynchronous execution of a single tool call and returns the
    /// task driving it.
    ///
    /// Disabled tools resolve immediately to an error result instead of
    /// running. When the tool produces a UI card, it is stored up front so the
    /// UI can render it while the tool is still running; the final output is
    /// recorded via `insert_tool_output` followed by `tool_finished`.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        // `false`: the tool ran to completion (not canceled).
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
1745
1746 fn tool_finished(
1747 &mut self,
1748 tool_use_id: LanguageModelToolUseId,
1749 pending_tool_use: Option<PendingToolUse>,
1750 canceled: bool,
1751 window: Option<AnyWindowHandle>,
1752 cx: &mut Context<Self>,
1753 ) {
1754 if self.all_tools_finished() {
1755 let model_registry = LanguageModelRegistry::read_global(cx);
1756 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1757 if !canceled {
1758 self.send_to_model(model, window, cx);
1759 }
1760 self.auto_capture_telemetry(cx);
1761 }
1762 }
1763
1764 cx.emit(ThreadEvent::ToolFinished {
1765 tool_use_id,
1766 pending_tool_use,
1767 });
1768 }
1769
1770 /// Cancels the last pending completion, if there are any pending.
1771 ///
1772 /// Returns whether a completion was canceled.
1773 pub fn cancel_last_completion(
1774 &mut self,
1775 window: Option<AnyWindowHandle>,
1776 cx: &mut Context<Self>,
1777 ) -> bool {
1778 let canceled = if self.pending_completions.pop().is_some() {
1779 true
1780 } else {
1781 let mut canceled = false;
1782 for pending_tool_use in self.tool_use.cancel_pending() {
1783 canceled = true;
1784 self.tool_finished(
1785 pending_tool_use.id.clone(),
1786 Some(pending_tool_use),
1787 true,
1788 window,
1789 cx,
1790 );
1791 }
1792 canceled
1793 };
1794 self.finalize_pending_checkpoint(cx);
1795 canceled
1796 }
1797
    /// Returns the thread-level feedback, if the user has rated this thread.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1801
    /// Returns the feedback recorded for `message_id`, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1805
1806 pub fn report_message_feedback(
1807 &mut self,
1808 message_id: MessageId,
1809 feedback: ThreadFeedback,
1810 cx: &mut Context<Self>,
1811 ) -> Task<Result<()>> {
1812 if self.message_feedback.get(&message_id) == Some(&feedback) {
1813 return Task::ready(Ok(()));
1814 }
1815
1816 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1817 let serialized_thread = self.serialize(cx);
1818 let thread_id = self.id().clone();
1819 let client = self.project.read(cx).client();
1820
1821 let enabled_tool_names: Vec<String> = self
1822 .tools()
1823 .read(cx)
1824 .enabled_tools(cx)
1825 .iter()
1826 .map(|tool| tool.name().to_string())
1827 .collect();
1828
1829 self.message_feedback.insert(message_id, feedback);
1830
1831 cx.notify();
1832
1833 let message_content = self
1834 .message(message_id)
1835 .map(|msg| msg.to_string())
1836 .unwrap_or_default();
1837
1838 cx.background_spawn(async move {
1839 let final_project_snapshot = final_project_snapshot.await;
1840 let serialized_thread = serialized_thread.await?;
1841 let thread_data =
1842 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1843
1844 let rating = match feedback {
1845 ThreadFeedback::Positive => "positive",
1846 ThreadFeedback::Negative => "negative",
1847 };
1848 telemetry::event!(
1849 "Assistant Thread Rated",
1850 rating,
1851 thread_id,
1852 enabled_tool_names,
1853 message_id = message_id.0,
1854 message_content,
1855 thread_data,
1856 final_project_snapshot
1857 );
1858 client.telemetry().flush_events().await;
1859
1860 Ok(())
1861 })
1862 }
1863
1864 pub fn report_feedback(
1865 &mut self,
1866 feedback: ThreadFeedback,
1867 cx: &mut Context<Self>,
1868 ) -> Task<Result<()>> {
1869 let last_assistant_message_id = self
1870 .messages
1871 .iter()
1872 .rev()
1873 .find(|msg| msg.role == Role::Assistant)
1874 .map(|msg| msg.id);
1875
1876 if let Some(message_id) = last_assistant_message_id {
1877 self.report_message_feedback(message_id, feedback, cx)
1878 } else {
1879 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1880 let serialized_thread = self.serialize(cx);
1881 let thread_id = self.id().clone();
1882 let client = self.project.read(cx).client();
1883 self.feedback = Some(feedback);
1884 cx.notify();
1885
1886 cx.background_spawn(async move {
1887 let final_project_snapshot = final_project_snapshot.await;
1888 let serialized_thread = serialized_thread.await?;
1889 let thread_data = serde_json::to_value(serialized_thread)
1890 .unwrap_or_else(|_| serde_json::Value::Null);
1891
1892 let rating = match feedback {
1893 ThreadFeedback::Positive => "positive",
1894 ThreadFeedback::Negative => "negative",
1895 };
1896 telemetry::event!(
1897 "Assistant Thread Rated",
1898 rating,
1899 thread_id,
1900 thread_data,
1901 final_project_snapshot
1902 );
1903 client.telemetry().flush_events().await;
1904
1905 Ok(())
1906 })
1907 }
1908 }
1909
1910 /// Create a snapshot of the current project state including git information and unsaved buffers.
1911 fn project_snapshot(
1912 project: Entity<Project>,
1913 cx: &mut Context<Self>,
1914 ) -> Task<Arc<ProjectSnapshot>> {
1915 let git_store = project.read(cx).git_store().clone();
1916 let worktree_snapshots: Vec<_> = project
1917 .read(cx)
1918 .visible_worktrees(cx)
1919 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1920 .collect();
1921
1922 cx.spawn(async move |_, cx| {
1923 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1924
1925 let mut unsaved_buffers = Vec::new();
1926 cx.update(|app_cx| {
1927 let buffer_store = project.read(app_cx).buffer_store();
1928 for buffer_handle in buffer_store.read(app_cx).buffers() {
1929 let buffer = buffer_handle.read(app_cx);
1930 if buffer.is_dirty() {
1931 if let Some(file) = buffer.file() {
1932 let path = file.path().to_string_lossy().to_string();
1933 unsaved_buffers.push(path);
1934 }
1935 }
1936 }
1937 })
1938 .ok();
1939
1940 Arc::new(ProjectSnapshot {
1941 worktree_snapshots,
1942 unsaved_buffer_paths: unsaved_buffers,
1943 timestamp: Utc::now(),
1944 })
1945 })
1946 }
1947
1948 fn worktree_snapshot(
1949 worktree: Entity<project::Worktree>,
1950 git_store: Entity<GitStore>,
1951 cx: &App,
1952 ) -> Task<WorktreeSnapshot> {
1953 cx.spawn(async move |cx| {
1954 // Get worktree path and snapshot
1955 let worktree_info = cx.update(|app_cx| {
1956 let worktree = worktree.read(app_cx);
1957 let path = worktree.abs_path().to_string_lossy().to_string();
1958 let snapshot = worktree.snapshot();
1959 (path, snapshot)
1960 });
1961
1962 let Ok((worktree_path, _snapshot)) = worktree_info else {
1963 return WorktreeSnapshot {
1964 worktree_path: String::new(),
1965 git_state: None,
1966 };
1967 };
1968
1969 let git_state = git_store
1970 .update(cx, |git_store, cx| {
1971 git_store
1972 .repositories()
1973 .values()
1974 .find(|repo| {
1975 repo.read(cx)
1976 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1977 .is_some()
1978 })
1979 .cloned()
1980 })
1981 .ok()
1982 .flatten()
1983 .map(|repo| {
1984 repo.update(cx, |repo, _| {
1985 let current_branch =
1986 repo.branch.as_ref().map(|branch| branch.name.to_string());
1987 repo.send_job(None, |state, _| async move {
1988 let RepositoryState::Local { backend, .. } = state else {
1989 return GitState {
1990 remote_url: None,
1991 head_sha: None,
1992 current_branch,
1993 diff: None,
1994 };
1995 };
1996
1997 let remote_url = backend.remote_url("origin");
1998 let head_sha = backend.head_sha();
1999 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2000
2001 GitState {
2002 remote_url,
2003 head_sha,
2004 current_branch,
2005 diff,
2006 }
2007 })
2008 })
2009 });
2010
2011 let git_state = match git_state {
2012 Some(git_state) => match git_state.ok() {
2013 Some(git_state) => git_state.await.ok(),
2014 None => None,
2015 },
2016 None => None,
2017 };
2018
2019 WorktreeSnapshot {
2020 worktree_path,
2021 git_state,
2022 }
2023 })
2024 }
2025
2026 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2027 let mut markdown = Vec::new();
2028
2029 if let Some(summary) = self.summary() {
2030 writeln!(markdown, "# {summary}\n")?;
2031 };
2032
2033 for message in self.messages() {
2034 writeln!(
2035 markdown,
2036 "## {role}\n",
2037 role = match message.role {
2038 Role::User => "User",
2039 Role::Assistant => "Assistant",
2040 Role::System => "System",
2041 }
2042 )?;
2043
2044 if !message.context.is_empty() {
2045 writeln!(markdown, "{}", message.context)?;
2046 }
2047
2048 for segment in &message.segments {
2049 match segment {
2050 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2051 MessageSegment::Thinking { text, .. } => {
2052 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2053 }
2054 MessageSegment::RedactedThinking(_) => {}
2055 }
2056 }
2057
2058 for tool_use in self.tool_uses_for_message(message.id, cx) {
2059 writeln!(
2060 markdown,
2061 "**Use Tool: {} ({})**",
2062 tool_use.name, tool_use.id
2063 )?;
2064 writeln!(markdown, "```json")?;
2065 writeln!(
2066 markdown,
2067 "{}",
2068 serde_json::to_string_pretty(&tool_use.input)?
2069 )?;
2070 writeln!(markdown, "```")?;
2071 }
2072
2073 for tool_result in self.tool_results_for_message(message.id) {
2074 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2075 if tool_result.is_error {
2076 write!(markdown, " (Error)")?;
2077 }
2078
2079 writeln!(markdown, "**\n")?;
2080 writeln!(markdown, "{}", tool_result.content)?;
2081 }
2082 }
2083
2084 Ok(String::from_utf8_lossy(&markdown).to_string())
2085 }
2086
    /// Accepts (keeps) the agent's edits within `buffer_range` of `buffer`,
    /// delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2097
    /// Accepts (keeps) every outstanding agent edit across all buffers.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2102
    /// Reverts the agent's edits within `buffer_ranges` of `buffer`, returning
    /// the action log's task that performs the rejection.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2113
    /// The action log tracking the agent's buffer reads and edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2117
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2121
2122 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2123 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2124 return;
2125 }
2126
2127 let now = Instant::now();
2128 if let Some(last) = self.last_auto_capture_at {
2129 if now.duration_since(last).as_secs() < 10 {
2130 return;
2131 }
2132 }
2133
2134 self.last_auto_capture_at = Some(now);
2135
2136 let thread_id = self.id().clone();
2137 let github_login = self
2138 .project
2139 .read(cx)
2140 .user_store()
2141 .read(cx)
2142 .current_user()
2143 .map(|user| user.github_login.clone());
2144 let client = self.project.read(cx).client().clone();
2145 let serialize_task = self.serialize(cx);
2146
2147 cx.background_executor()
2148 .spawn(async move {
2149 if let Ok(serialized_thread) = serialize_task.await {
2150 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2151 telemetry::event!(
2152 "Agent Thread Auto-Captured",
2153 thread_id = thread_id.to_string(),
2154 thread_data = thread_data,
2155 auto_capture_reason = "tracked_user",
2156 github_login = github_login
2157 );
2158
2159 client.telemetry().flush_events().await;
2160 }
2161 }
2162 })
2163 .detach();
2164 }
2165
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2169
2170 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2171 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2172 return TotalTokenUsage::default();
2173 };
2174
2175 let max = model.model.max_token_count();
2176
2177 let index = self
2178 .messages
2179 .iter()
2180 .position(|msg| msg.id == message_id)
2181 .unwrap_or(0);
2182
2183 if index == 0 {
2184 return TotalTokenUsage { total: 0, max };
2185 }
2186
2187 let token_usage = &self
2188 .request_token_usage
2189 .get(index - 1)
2190 .cloned()
2191 .unwrap_or_default();
2192
2193 TotalTokenUsage {
2194 total: token_usage.total_tokens() as usize,
2195 max,
2196 }
2197 }
2198
2199 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2200 let model_registry = LanguageModelRegistry::read_global(cx);
2201 let Some(model) = model_registry.default_model() else {
2202 return TotalTokenUsage::default();
2203 };
2204
2205 let max = model.model.max_token_count();
2206
2207 if let Some(exceeded_error) = &self.exceeded_window_error {
2208 if model.model.id() == exceeded_error.model_id {
2209 return TotalTokenUsage {
2210 total: exceeded_error.token_count,
2211 max,
2212 };
2213 }
2214 }
2215
2216 let total = self
2217 .token_usage_at_last_message()
2218 .unwrap_or_default()
2219 .total_tokens() as usize;
2220
2221 TotalTokenUsage { total, max }
2222 }
2223
    /// Returns the request token usage associated with the last message.
    ///
    /// Falls back to the most recent recorded usage when there is no entry at
    /// the last message's index (usage entries can lag behind messages).
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2230
    /// Records `token_usage` as the usage for the last message.
    ///
    /// Grows `request_token_usage` to match the message count — back-filling
    /// any new slots with the previously known usage — then overwrites the
    /// final entry with the new value.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2240
2241 pub fn deny_tool_use(
2242 &mut self,
2243 tool_use_id: LanguageModelToolUseId,
2244 tool_name: Arc<str>,
2245 window: Option<AnyWindowHandle>,
2246 cx: &mut Context<Self>,
2247 ) {
2248 let err = Err(anyhow::anyhow!(
2249 "Permission to run tool action denied by user"
2250 ));
2251
2252 self.tool_use
2253 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2254 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2255 }
2256}
2257
/// User-visible failure modes surfaced while interacting with a language model.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider requires payment before serving further requests.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's configured monthly spending cap has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request quota for the user's plan has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a short header and a detailed message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2272
/// Events emitted by a [`Thread`] as completions stream in and tools execute.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    /// Updated request-usage information arrived from the provider.
    UsageUpdated(RequestUsage),
    /// A streaming event was applied to the thread (text, tool use, etc.).
    StreamedCompletion,
    /// A text chunk was received from the model.
    ReceivedTextChunk,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// Partial tool-use input streamed in before the input was complete.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A new message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A thread title was generated by the summary model.
    SummaryGenerated,
    /// The thread summary was changed (e.g. renamed by the user).
    SummaryChanged,
    /// The model requested tool uses that are now being dispatched.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished producing output (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A git checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool requires explicit user confirmation before it can run.
    ToolConfirmationNeeded,
}
2304
// Lets `Thread` entities emit `ThreadEvent`s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2306
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier distinguishing this pending completion from others.
    id: usize,
    // Retains the streaming task so it keeps running while this entry exists.
    _task: Task<()>,
}
2311
// Tests covering how `Thread` formats attached context into messages and
// completion requests, and how it reports stale (modified) buffers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // Attached file context should be formatted into the message and placed
    // ahead of the user's text in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // messages[0] is the system prompt; messages[1] is the user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Context already attached to an earlier message should not be repeated
    // in later messages — only newly-added context is included.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain all 3 messages (plus the system prompt at index 0)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    // Messages inserted with an empty context vector should carry no context
    // text, in either the message object or the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Editing a buffer after it was attached as context should cause the
    // next completion request to end with a stale-buffer notification.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at byte offset 1 (just after the first character) — any
            // edit is enough to mark the buffer as changed since last read.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Registers every settings type the thread/test harness depends on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    /// Helper to create a test project rooted at `/test` containing `files`
    /// (a `serde_json` tree consumed by [`FakeFs::insert_tree`]).
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Builds the standard test fixture: a workspace window, a thread store,
    /// a fresh thread, and an empty context store for `project`.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Opens `path` in `project` and attaches the resulting buffer to
    /// `context_store`, returning the buffer for further edits.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}