1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
18};
19use language_model::{
20 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
21 LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
22 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
23 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
24 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
25 TokenUsage,
26};
27use project::Project;
28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
29use prompt_store::PromptBuilder;
30use proto::Plan;
31use schemars::JsonSchema;
32use serde::{Deserialize, Serialize};
33use settings::Settings;
34use thiserror::Error;
35use util::{ResultExt as _, TryFutureExt as _, post_inc};
36use uuid::Uuid;
37
38use crate::context::{AssistantContext, ContextId, format_context_as_string};
39use crate::thread_store::{
40 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
41 SerializedToolUse, SharedProjectContext,
42};
43use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
44
45#[derive(
46 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
47)]
48pub struct ThreadId(Arc<str>);
49
50impl ThreadId {
51 pub fn new() -> Self {
52 Self(Uuid::new_v4().to_string().into())
53 }
54}
55
56impl std::fmt::Display for ThreadId {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 write!(f, "{}", self.0)
59 }
60}
61
62impl From<&str> for ThreadId {
63 fn from(value: &str) -> Self {
64 Self(value.into())
65 }
66}
67
68/// The ID of the user prompt that initiated a request.
69///
70/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
71#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
72pub struct PromptId(Arc<str>);
73
74impl PromptId {
75 pub fn new() -> Self {
76 Self(Uuid::new_v4().to_string().into())
77 }
78}
79
80impl std::fmt::Display for PromptId {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 write!(f, "{}", self.0)
83 }
84}
85
86#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
87pub struct MessageId(pub(crate) usize);
88
89impl MessageId {
90 fn post_inc(&mut self) -> Self {
91 Self(post_inc(&mut self.0))
92 }
93}
94
95/// A message in a [`Thread`].
96#[derive(Debug, Clone)]
97pub struct Message {
98 pub id: MessageId,
99 pub role: Role,
100 pub segments: Vec<MessageSegment>,
101 pub context: String,
102 pub images: Vec<LanguageModelImage>,
103}
104
105impl Message {
106 /// Returns whether the message contains any meaningful text that should be displayed
107 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
108 pub fn should_display_content(&self) -> bool {
109 self.segments.iter().all(|segment| segment.should_display())
110 }
111
112 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
113 if let Some(MessageSegment::Thinking {
114 text: segment,
115 signature: current_signature,
116 }) = self.segments.last_mut()
117 {
118 if let Some(signature) = signature {
119 *current_signature = Some(signature);
120 }
121 segment.push_str(text);
122 } else {
123 self.segments.push(MessageSegment::Thinking {
124 text: text.to_string(),
125 signature,
126 });
127 }
128 }
129
130 pub fn push_text(&mut self, text: &str) {
131 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
132 segment.push_str(text);
133 } else {
134 self.segments.push(MessageSegment::Text(text.to_string()));
135 }
136 }
137
138 pub fn to_string(&self) -> String {
139 let mut result = String::new();
140
141 if !self.context.is_empty() {
142 result.push_str(&self.context);
143 }
144
145 for segment in &self.segments {
146 match segment {
147 MessageSegment::Text(text) => result.push_str(text),
148 MessageSegment::Thinking { text, .. } => {
149 result.push_str("<think>\n");
150 result.push_str(text);
151 result.push_str("\n</think>");
152 }
153 MessageSegment::RedactedThinking(_) => {}
154 }
155 }
156
157 result
158 }
159}
160
161#[derive(Debug, Clone, PartialEq, Eq)]
162pub enum MessageSegment {
163 Text(String),
164 Thinking {
165 text: String,
166 signature: Option<String>,
167 },
168 RedactedThinking(Vec<u8>),
169}
170
171impl MessageSegment {
172 pub fn should_display(&self) -> bool {
173 match self {
174 Self::Text(text) => text.is_empty(),
175 Self::Thinking { text, .. } => text.is_empty(),
176 Self::RedactedThinking(_) => false,
177 }
178 }
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct ProjectSnapshot {
183 pub worktree_snapshots: Vec<WorktreeSnapshot>,
184 pub unsaved_buffer_paths: Vec<String>,
185 pub timestamp: DateTime<Utc>,
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct WorktreeSnapshot {
190 pub worktree_path: String,
191 pub git_state: Option<GitState>,
192}
193
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct GitState {
196 pub remote_url: Option<String>,
197 pub head_sha: Option<String>,
198 pub current_branch: Option<String>,
199 pub diff: Option<String>,
200}
201
202#[derive(Clone)]
203pub struct ThreadCheckpoint {
204 message_id: MessageId,
205 git_checkpoint: GitStoreCheckpoint,
206}
207
208#[derive(Copy, Clone, Debug, PartialEq, Eq)]
209pub enum ThreadFeedback {
210 Positive,
211 Negative,
212}
213
214pub enum LastRestoreCheckpoint {
215 Pending {
216 message_id: MessageId,
217 },
218 Error {
219 message_id: MessageId,
220 error: String,
221 },
222}
223
224impl LastRestoreCheckpoint {
225 pub fn message_id(&self) -> MessageId {
226 match self {
227 LastRestoreCheckpoint::Pending { message_id } => *message_id,
228 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
229 }
230 }
231}
232
233#[derive(Clone, Debug, Default, Serialize, Deserialize)]
234pub enum DetailedSummaryState {
235 #[default]
236 NotGenerated,
237 Generating {
238 message_id: MessageId,
239 },
240 Generated {
241 text: SharedString,
242 message_id: MessageId,
243 },
244}
245
246#[derive(Default)]
247pub struct TotalTokenUsage {
248 pub total: usize,
249 pub max: usize,
250}
251
252impl TotalTokenUsage {
253 pub fn ratio(&self) -> TokenUsageRatio {
254 #[cfg(debug_assertions)]
255 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
256 .unwrap_or("0.8".to_string())
257 .parse()
258 .unwrap();
259 #[cfg(not(debug_assertions))]
260 let warning_threshold: f32 = 0.8;
261
262 if self.total >= self.max {
263 TokenUsageRatio::Exceeded
264 } else if self.total as f32 / self.max as f32 >= warning_threshold {
265 TokenUsageRatio::Warning
266 } else {
267 TokenUsageRatio::Normal
268 }
269 }
270
271 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
272 TotalTokenUsage {
273 total: self.total + tokens,
274 max: self.max,
275 }
276 }
277}
278
279#[derive(Debug, Default, PartialEq, Eq)]
280pub enum TokenUsageRatio {
281 #[default]
282 Normal,
283 Warning,
284 Exceeded,
285}
286
287/// A thread of conversation with the LLM.
288pub struct Thread {
289 id: ThreadId,
290 updated_at: DateTime<Utc>,
291 summary: Option<SharedString>,
292 pending_summary: Task<Option<()>>,
293 detailed_summary_state: DetailedSummaryState,
294 messages: Vec<Message>,
295 next_message_id: MessageId,
296 last_prompt_id: PromptId,
297 context: BTreeMap<ContextId, AssistantContext>,
298 context_by_message: HashMap<MessageId, Vec<ContextId>>,
299 project_context: SharedProjectContext,
300 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
301 completion_count: usize,
302 pending_completions: Vec<PendingCompletion>,
303 project: Entity<Project>,
304 prompt_builder: Arc<PromptBuilder>,
305 tools: Entity<ToolWorkingSet>,
306 tool_use: ToolUseState,
307 action_log: Entity<ActionLog>,
308 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
309 pending_checkpoint: Option<ThreadCheckpoint>,
310 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
311 request_token_usage: Vec<TokenUsage>,
312 cumulative_token_usage: TokenUsage,
313 exceeded_window_error: Option<ExceededWindowError>,
314 feedback: Option<ThreadFeedback>,
315 message_feedback: HashMap<MessageId, ThreadFeedback>,
316 last_auto_capture_at: Option<Instant>,
317 request_callback: Option<
318 Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
319 >,
320 remaining_turns: u32,
321}
322
323#[derive(Debug, Clone, Serialize, Deserialize)]
324pub struct ExceededWindowError {
325 /// Model used when last message exceeded context window
326 model_id: LanguageModelId,
327 /// Token count including last message
328 token_count: usize,
329}
330
331impl Thread {
332 pub fn new(
333 project: Entity<Project>,
334 tools: Entity<ToolWorkingSet>,
335 prompt_builder: Arc<PromptBuilder>,
336 system_prompt: SharedProjectContext,
337 cx: &mut Context<Self>,
338 ) -> Self {
339 Self {
340 id: ThreadId::new(),
341 updated_at: Utc::now(),
342 summary: None,
343 pending_summary: Task::ready(None),
344 detailed_summary_state: DetailedSummaryState::NotGenerated,
345 messages: Vec::new(),
346 next_message_id: MessageId(0),
347 last_prompt_id: PromptId::new(),
348 context: BTreeMap::default(),
349 context_by_message: HashMap::default(),
350 project_context: system_prompt,
351 checkpoints_by_message: HashMap::default(),
352 completion_count: 0,
353 pending_completions: Vec::new(),
354 project: project.clone(),
355 prompt_builder,
356 tools: tools.clone(),
357 last_restore_checkpoint: None,
358 pending_checkpoint: None,
359 tool_use: ToolUseState::new(tools.clone()),
360 action_log: cx.new(|_| ActionLog::new(project.clone())),
361 initial_project_snapshot: {
362 let project_snapshot = Self::project_snapshot(project, cx);
363 cx.foreground_executor()
364 .spawn(async move { Some(project_snapshot.await) })
365 .shared()
366 },
367 request_token_usage: Vec::new(),
368 cumulative_token_usage: TokenUsage::default(),
369 exceeded_window_error: None,
370 feedback: None,
371 message_feedback: HashMap::default(),
372 last_auto_capture_at: None,
373 request_callback: None,
374 remaining_turns: u32::MAX,
375 }
376 }
377
378 pub fn deserialize(
379 id: ThreadId,
380 serialized: SerializedThread,
381 project: Entity<Project>,
382 tools: Entity<ToolWorkingSet>,
383 prompt_builder: Arc<PromptBuilder>,
384 project_context: SharedProjectContext,
385 cx: &mut Context<Self>,
386 ) -> Self {
387 let next_message_id = MessageId(
388 serialized
389 .messages
390 .last()
391 .map(|message| message.id.0 + 1)
392 .unwrap_or(0),
393 );
394 let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
395
396 Self {
397 id,
398 updated_at: serialized.updated_at,
399 summary: Some(serialized.summary),
400 pending_summary: Task::ready(None),
401 detailed_summary_state: serialized.detailed_summary_state,
402 messages: serialized
403 .messages
404 .into_iter()
405 .map(|message| Message {
406 id: message.id,
407 role: message.role,
408 segments: message
409 .segments
410 .into_iter()
411 .map(|segment| match segment {
412 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
413 SerializedMessageSegment::Thinking { text, signature } => {
414 MessageSegment::Thinking { text, signature }
415 }
416 SerializedMessageSegment::RedactedThinking { data } => {
417 MessageSegment::RedactedThinking(data)
418 }
419 })
420 .collect(),
421 context: message.context,
422 images: Vec::new(),
423 })
424 .collect(),
425 next_message_id,
426 last_prompt_id: PromptId::new(),
427 context: BTreeMap::default(),
428 context_by_message: HashMap::default(),
429 project_context,
430 checkpoints_by_message: HashMap::default(),
431 completion_count: 0,
432 pending_completions: Vec::new(),
433 last_restore_checkpoint: None,
434 pending_checkpoint: None,
435 project: project.clone(),
436 prompt_builder,
437 tools,
438 tool_use,
439 action_log: cx.new(|_| ActionLog::new(project)),
440 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
441 request_token_usage: serialized.request_token_usage,
442 cumulative_token_usage: serialized.cumulative_token_usage,
443 exceeded_window_error: None,
444 feedback: None,
445 message_feedback: HashMap::default(),
446 last_auto_capture_at: None,
447 request_callback: None,
448 remaining_turns: u32::MAX,
449 }
450 }
451
452 pub fn set_request_callback(
453 &mut self,
454 callback: impl 'static
455 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
456 ) {
457 self.request_callback = Some(Box::new(callback));
458 }
459
460 pub fn id(&self) -> &ThreadId {
461 &self.id
462 }
463
464 pub fn is_empty(&self) -> bool {
465 self.messages.is_empty()
466 }
467
468 pub fn updated_at(&self) -> DateTime<Utc> {
469 self.updated_at
470 }
471
472 pub fn touch_updated_at(&mut self) {
473 self.updated_at = Utc::now();
474 }
475
476 pub fn advance_prompt_id(&mut self) {
477 self.last_prompt_id = PromptId::new();
478 }
479
480 pub fn summary(&self) -> Option<SharedString> {
481 self.summary.clone()
482 }
483
484 pub fn project_context(&self) -> SharedProjectContext {
485 self.project_context.clone()
486 }
487
488 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
489
490 pub fn summary_or_default(&self) -> SharedString {
491 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
492 }
493
494 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
495 let Some(current_summary) = &self.summary else {
496 // Don't allow setting summary until generated
497 return;
498 };
499
500 let mut new_summary = new_summary.into();
501
502 if new_summary.is_empty() {
503 new_summary = Self::DEFAULT_SUMMARY;
504 }
505
506 if current_summary != &new_summary {
507 self.summary = Some(new_summary);
508 cx.emit(ThreadEvent::SummaryChanged);
509 }
510 }
511
512 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
513 self.latest_detailed_summary()
514 .unwrap_or_else(|| self.text().into())
515 }
516
517 fn latest_detailed_summary(&self) -> Option<SharedString> {
518 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
519 Some(text.clone())
520 } else {
521 None
522 }
523 }
524
525 pub fn message(&self, id: MessageId) -> Option<&Message> {
526 let index = self
527 .messages
528 .binary_search_by(|message| message.id.cmp(&id))
529 .ok()?;
530
531 self.messages.get(index)
532 }
533
534 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
535 self.messages.iter()
536 }
537
538 pub fn is_generating(&self) -> bool {
539 !self.pending_completions.is_empty() || !self.all_tools_finished()
540 }
541
542 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
543 &self.tools
544 }
545
546 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
547 self.tool_use
548 .pending_tool_uses()
549 .into_iter()
550 .find(|tool_use| &tool_use.id == id)
551 }
552
553 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
554 self.tool_use
555 .pending_tool_uses()
556 .into_iter()
557 .filter(|tool_use| tool_use.status.needs_confirmation())
558 }
559
560 pub fn has_pending_tool_uses(&self) -> bool {
561 !self.tool_use.pending_tool_uses().is_empty()
562 }
563
564 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
565 self.checkpoints_by_message.get(&id).cloned()
566 }
567
568 pub fn restore_checkpoint(
569 &mut self,
570 checkpoint: ThreadCheckpoint,
571 cx: &mut Context<Self>,
572 ) -> Task<Result<()>> {
573 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
574 message_id: checkpoint.message_id,
575 });
576 cx.emit(ThreadEvent::CheckpointChanged);
577 cx.notify();
578
579 let git_store = self.project().read(cx).git_store().clone();
580 let restore = git_store.update(cx, |git_store, cx| {
581 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
582 });
583
584 cx.spawn(async move |this, cx| {
585 let result = restore.await;
586 this.update(cx, |this, cx| {
587 if let Err(err) = result.as_ref() {
588 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
589 message_id: checkpoint.message_id,
590 error: err.to_string(),
591 });
592 } else {
593 this.truncate(checkpoint.message_id, cx);
594 this.last_restore_checkpoint = None;
595 }
596 this.pending_checkpoint = None;
597 cx.emit(ThreadEvent::CheckpointChanged);
598 cx.notify();
599 })?;
600 result
601 })
602 }
603
604 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
605 let pending_checkpoint = if self.is_generating() {
606 return;
607 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
608 checkpoint
609 } else {
610 return;
611 };
612
613 let git_store = self.project.read(cx).git_store().clone();
614 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
615 cx.spawn(async move |this, cx| match final_checkpoint.await {
616 Ok(final_checkpoint) => {
617 let equal = git_store
618 .update(cx, |store, cx| {
619 store.compare_checkpoints(
620 pending_checkpoint.git_checkpoint.clone(),
621 final_checkpoint.clone(),
622 cx,
623 )
624 })?
625 .await
626 .unwrap_or(false);
627
628 if !equal {
629 this.update(cx, |this, cx| {
630 this.insert_checkpoint(pending_checkpoint, cx)
631 })?;
632 }
633
634 Ok(())
635 }
636 Err(_) => this.update(cx, |this, cx| {
637 this.insert_checkpoint(pending_checkpoint, cx)
638 }),
639 })
640 .detach();
641 }
642
643 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
644 self.checkpoints_by_message
645 .insert(checkpoint.message_id, checkpoint);
646 cx.emit(ThreadEvent::CheckpointChanged);
647 cx.notify();
648 }
649
650 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
651 self.last_restore_checkpoint.as_ref()
652 }
653
654 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
655 let Some(message_ix) = self
656 .messages
657 .iter()
658 .rposition(|message| message.id == message_id)
659 else {
660 return;
661 };
662 for deleted_message in self.messages.drain(message_ix..) {
663 self.context_by_message.remove(&deleted_message.id);
664 self.checkpoints_by_message.remove(&deleted_message.id);
665 }
666 cx.notify();
667 }
668
669 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
670 self.context_by_message
671 .get(&id)
672 .into_iter()
673 .flat_map(|context| {
674 context
675 .iter()
676 .filter_map(|context_id| self.context.get(&context_id))
677 })
678 }
679
680 pub fn is_turn_end(&self, ix: usize) -> bool {
681 if self.messages.is_empty() {
682 return false;
683 }
684
685 if !self.is_generating() && ix == self.messages.len() - 1 {
686 return true;
687 }
688
689 let Some(message) = self.messages.get(ix) else {
690 return false;
691 };
692
693 if message.role != Role::Assistant {
694 return false;
695 }
696
697 self.messages
698 .get(ix + 1)
699 .and_then(|message| {
700 self.message(message.id)
701 .map(|next_message| next_message.role == Role::User)
702 })
703 .unwrap_or(false)
704 }
705
706 /// Returns whether all of the tool uses have finished running.
707 pub fn all_tools_finished(&self) -> bool {
708 // If the only pending tool uses left are the ones with errors, then
709 // that means that we've finished running all of the pending tools.
710 self.tool_use
711 .pending_tool_uses()
712 .iter()
713 .all(|tool_use| tool_use.status.is_error())
714 }
715
716 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
717 self.tool_use.tool_uses_for_message(id, cx)
718 }
719
720 pub fn tool_results_for_message(
721 &self,
722 assistant_message_id: MessageId,
723 ) -> Vec<&LanguageModelToolResult> {
724 self.tool_use.tool_results_for_message(assistant_message_id)
725 }
726
727 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
728 self.tool_use.tool_result(id)
729 }
730
731 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
732 Some(&self.tool_use.tool_result(id)?.content)
733 }
734
735 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
736 self.tool_use.tool_result_card(id).cloned()
737 }
738
739 /// Filter out contexts that have already been included in previous messages
740 pub fn filter_new_context<'a>(
741 &self,
742 context: impl Iterator<Item = &'a AssistantContext>,
743 ) -> impl Iterator<Item = &'a AssistantContext> {
744 context.filter(|ctx| self.is_context_new(ctx))
745 }
746
747 fn is_context_new(&self, context: &AssistantContext) -> bool {
748 !self.context.contains_key(&context.id())
749 }
750
751 pub fn insert_user_message(
752 &mut self,
753 text: impl Into<String>,
754 context: Vec<AssistantContext>,
755 git_checkpoint: Option<GitStoreCheckpoint>,
756 cx: &mut Context<Self>,
757 ) -> MessageId {
758 let text = text.into();
759
760 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
761
762 let new_context: Vec<_> = context
763 .into_iter()
764 .filter(|ctx| self.is_context_new(ctx))
765 .collect();
766
767 if !new_context.is_empty() {
768 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
769 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
770 message.context = context_string;
771 }
772 }
773
774 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
775 message.images = new_context
776 .iter()
777 .filter_map(|context| {
778 if let AssistantContext::Image(image_context) = context {
779 image_context.image_task.clone().now_or_never().flatten()
780 } else {
781 None
782 }
783 })
784 .collect::<Vec<_>>();
785 }
786
787 self.action_log.update(cx, |log, cx| {
788 // Track all buffers added as context
789 for ctx in &new_context {
790 match ctx {
791 AssistantContext::File(file_ctx) => {
792 log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
793 }
794 AssistantContext::Directory(dir_ctx) => {
795 for context_buffer in &dir_ctx.context_buffers {
796 log.track_buffer(context_buffer.buffer.clone(), cx);
797 }
798 }
799 AssistantContext::Symbol(symbol_ctx) => {
800 log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
801 }
802 AssistantContext::Selection(selection_context) => {
803 log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
804 }
805 AssistantContext::FetchedUrl(_)
806 | AssistantContext::Thread(_)
807 | AssistantContext::Rules(_)
808 | AssistantContext::Image(_) => {}
809 }
810 }
811 });
812 }
813
814 let context_ids = new_context
815 .iter()
816 .map(|context| context.id())
817 .collect::<Vec<_>>();
818 self.context.extend(
819 new_context
820 .into_iter()
821 .map(|context| (context.id(), context)),
822 );
823 self.context_by_message.insert(message_id, context_ids);
824
825 if let Some(git_checkpoint) = git_checkpoint {
826 self.pending_checkpoint = Some(ThreadCheckpoint {
827 message_id,
828 git_checkpoint,
829 });
830 }
831
832 self.auto_capture_telemetry(cx);
833
834 message_id
835 }
836
837 pub fn insert_message(
838 &mut self,
839 role: Role,
840 segments: Vec<MessageSegment>,
841 cx: &mut Context<Self>,
842 ) -> MessageId {
843 let id = self.next_message_id.post_inc();
844 self.messages.push(Message {
845 id,
846 role,
847 segments,
848 context: String::new(),
849 images: Vec::new(),
850 });
851 self.touch_updated_at();
852 cx.emit(ThreadEvent::MessageAdded(id));
853 id
854 }
855
856 pub fn edit_message(
857 &mut self,
858 id: MessageId,
859 new_role: Role,
860 new_segments: Vec<MessageSegment>,
861 cx: &mut Context<Self>,
862 ) -> bool {
863 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
864 return false;
865 };
866 message.role = new_role;
867 message.segments = new_segments;
868 self.touch_updated_at();
869 cx.emit(ThreadEvent::MessageEdited(id));
870 true
871 }
872
873 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
874 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
875 return false;
876 };
877 self.messages.remove(index);
878 self.context_by_message.remove(&id);
879 self.touch_updated_at();
880 cx.emit(ThreadEvent::MessageDeleted(id));
881 true
882 }
883
884 /// Returns the representation of this [`Thread`] in a textual form.
885 ///
886 /// This is the representation we use when attaching a thread as context to another thread.
887 pub fn text(&self) -> String {
888 let mut text = String::new();
889
890 for message in &self.messages {
891 text.push_str(match message.role {
892 language_model::Role::User => "User:",
893 language_model::Role::Assistant => "Assistant:",
894 language_model::Role::System => "System:",
895 });
896 text.push('\n');
897
898 for segment in &message.segments {
899 match segment {
900 MessageSegment::Text(content) => text.push_str(content),
901 MessageSegment::Thinking { text: content, .. } => {
902 text.push_str(&format!("<think>{}</think>", content))
903 }
904 MessageSegment::RedactedThinking(_) => {}
905 }
906 }
907 text.push('\n');
908 }
909
910 text
911 }
912
913 /// Serializes this thread into a format for storage or telemetry.
914 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
915 let initial_project_snapshot = self.initial_project_snapshot.clone();
916 cx.spawn(async move |this, cx| {
917 let initial_project_snapshot = initial_project_snapshot.await;
918 this.read_with(cx, |this, cx| SerializedThread {
919 version: SerializedThread::VERSION.to_string(),
920 summary: this.summary_or_default(),
921 updated_at: this.updated_at(),
922 messages: this
923 .messages()
924 .map(|message| SerializedMessage {
925 id: message.id,
926 role: message.role,
927 segments: message
928 .segments
929 .iter()
930 .map(|segment| match segment {
931 MessageSegment::Text(text) => {
932 SerializedMessageSegment::Text { text: text.clone() }
933 }
934 MessageSegment::Thinking { text, signature } => {
935 SerializedMessageSegment::Thinking {
936 text: text.clone(),
937 signature: signature.clone(),
938 }
939 }
940 MessageSegment::RedactedThinking(data) => {
941 SerializedMessageSegment::RedactedThinking {
942 data: data.clone(),
943 }
944 }
945 })
946 .collect(),
947 tool_uses: this
948 .tool_uses_for_message(message.id, cx)
949 .into_iter()
950 .map(|tool_use| SerializedToolUse {
951 id: tool_use.id,
952 name: tool_use.name,
953 input: tool_use.input,
954 })
955 .collect(),
956 tool_results: this
957 .tool_results_for_message(message.id)
958 .into_iter()
959 .map(|tool_result| SerializedToolResult {
960 tool_use_id: tool_result.tool_use_id.clone(),
961 is_error: tool_result.is_error,
962 content: tool_result.content.clone(),
963 })
964 .collect(),
965 context: message.context.clone(),
966 })
967 .collect(),
968 initial_project_snapshot,
969 cumulative_token_usage: this.cumulative_token_usage,
970 request_token_usage: this.request_token_usage.clone(),
971 detailed_summary_state: this.detailed_summary_state.clone(),
972 exceeded_window_error: this.exceeded_window_error.clone(),
973 })
974 })
975 }
976
977 pub fn remaining_turns(&self) -> u32 {
978 self.remaining_turns
979 }
980
981 pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
982 self.remaining_turns = remaining_turns;
983 }
984
985 pub fn send_to_model(
986 &mut self,
987 model: Arc<dyn LanguageModel>,
988 window: Option<AnyWindowHandle>,
989 cx: &mut Context<Self>,
990 ) {
991 if self.remaining_turns == 0 {
992 return;
993 }
994
995 self.remaining_turns -= 1;
996
997 let mut request = self.to_completion_request(cx);
998 if model.supports_tools() {
999 request.tools = {
1000 let mut tools = Vec::new();
1001 tools.extend(
1002 self.tools()
1003 .read(cx)
1004 .enabled_tools(cx)
1005 .into_iter()
1006 .filter_map(|tool| {
1007 // Skip tools that cannot be supported
1008 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
1009 Some(LanguageModelRequestTool {
1010 name: tool.name(),
1011 description: tool.description(),
1012 input_schema,
1013 })
1014 }),
1015 );
1016
1017 tools
1018 };
1019 }
1020
1021 self.stream_completion(request, model, window, cx);
1022 }
1023
1024 pub fn used_tools_since_last_user_message(&self) -> bool {
1025 for message in self.messages.iter().rev() {
1026 if self.tool_use.message_has_tool_results(message.id) {
1027 return true;
1028 } else if message.role == Role::User {
1029 return false;
1030 }
1031 }
1032
1033 false
1034 }
1035
1036 pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
1037 let mut request = LanguageModelRequest {
1038 thread_id: Some(self.id.to_string()),
1039 prompt_id: Some(self.last_prompt_id.to_string()),
1040 messages: vec![],
1041 tools: Vec::new(),
1042 stop: Vec::new(),
1043 temperature: None,
1044 };
1045
1046 if let Some(project_context) = self.project_context.borrow().as_ref() {
1047 match self
1048 .prompt_builder
1049 .generate_assistant_system_prompt(project_context)
1050 {
1051 Err(err) => {
1052 let message = format!("{err:?}").into();
1053 log::error!("{message}");
1054 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1055 header: "Error generating system prompt".into(),
1056 message,
1057 }));
1058 }
1059 Ok(system_prompt) => {
1060 request.messages.push(LanguageModelRequestMessage {
1061 role: Role::System,
1062 content: vec![MessageContent::Text(system_prompt)],
1063 cache: true,
1064 });
1065 }
1066 }
1067 } else {
1068 let message = "Context for system prompt unexpectedly not ready.".into();
1069 log::error!("{message}");
1070 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1071 header: "Error generating system prompt".into(),
1072 message,
1073 }));
1074 }
1075
1076 for message in &self.messages {
1077 let mut request_message = LanguageModelRequestMessage {
1078 role: message.role,
1079 content: Vec::new(),
1080 cache: false,
1081 };
1082
1083 if !message.context.is_empty() {
1084 request_message
1085 .content
1086 .push(MessageContent::Text(message.context.to_string()));
1087 }
1088
1089 if !message.images.is_empty() {
1090 // Some providers only support image parts after an initial text part
1091 if request_message.content.is_empty() {
1092 request_message
1093 .content
1094 .push(MessageContent::Text("Images attached by user:".to_string()));
1095 }
1096
1097 for image in &message.images {
1098 request_message
1099 .content
1100 .push(MessageContent::Image(image.clone()))
1101 }
1102 }
1103
1104 for segment in &message.segments {
1105 match segment {
1106 MessageSegment::Text(text) => {
1107 if !text.is_empty() {
1108 request_message
1109 .content
1110 .push(MessageContent::Text(text.into()));
1111 }
1112 }
1113 MessageSegment::Thinking { text, signature } => {
1114 if !text.is_empty() {
1115 request_message.content.push(MessageContent::Thinking {
1116 text: text.into(),
1117 signature: signature.clone(),
1118 });
1119 }
1120 }
1121 MessageSegment::RedactedThinking(data) => {
1122 request_message
1123 .content
1124 .push(MessageContent::RedactedThinking(data.clone()));
1125 }
1126 };
1127 }
1128
1129 self.tool_use
1130 .attach_tool_uses(message.id, &mut request_message);
1131
1132 request.messages.push(request_message);
1133
1134 if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
1135 request.messages.push(tool_results_message);
1136 }
1137 }
1138
1139 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1140 if let Some(last) = request.messages.last_mut() {
1141 last.cache = true;
1142 }
1143
1144 self.attached_tracked_files_state(&mut request.messages, cx);
1145
1146 request
1147 }
1148
1149 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1150 let mut request = LanguageModelRequest {
1151 thread_id: None,
1152 prompt_id: None,
1153 messages: vec![],
1154 tools: Vec::new(),
1155 stop: Vec::new(),
1156 temperature: None,
1157 };
1158
1159 for message in &self.messages {
1160 let mut request_message = LanguageModelRequestMessage {
1161 role: message.role,
1162 content: Vec::new(),
1163 cache: false,
1164 };
1165
1166 for segment in &message.segments {
1167 match segment {
1168 MessageSegment::Text(text) => request_message
1169 .content
1170 .push(MessageContent::Text(text.clone())),
1171 MessageSegment::Thinking { .. } => {}
1172 MessageSegment::RedactedThinking(_) => {}
1173 }
1174 }
1175
1176 if request_message.content.is_empty() {
1177 continue;
1178 }
1179
1180 request.messages.push(request_message);
1181 }
1182
1183 request.messages.push(LanguageModelRequestMessage {
1184 role: Role::User,
1185 content: vec![MessageContent::Text(added_user_message)],
1186 cache: false,
1187 });
1188
1189 request
1190 }
1191
1192 fn attached_tracked_files_state(
1193 &self,
1194 messages: &mut Vec<LanguageModelRequestMessage>,
1195 cx: &App,
1196 ) {
1197 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1198
1199 let mut stale_message = String::new();
1200
1201 let action_log = self.action_log.read(cx);
1202
1203 for stale_file in action_log.stale_buffers(cx) {
1204 let Some(file) = stale_file.read(cx).file() else {
1205 continue;
1206 };
1207
1208 if stale_message.is_empty() {
1209 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1210 }
1211
1212 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1213 }
1214
1215 let mut content = Vec::with_capacity(2);
1216
1217 if !stale_message.is_empty() {
1218 content.push(stale_message.into());
1219 }
1220
1221 if !content.is_empty() {
1222 let context_message = LanguageModelRequestMessage {
1223 role: Role::User,
1224 content,
1225 cache: false,
1226 };
1227
1228 messages.push(context_message);
1229 }
1230 }
1231
1232 pub fn stream_completion(
1233 &mut self,
1234 request: LanguageModelRequest,
1235 model: Arc<dyn LanguageModel>,
1236 window: Option<AnyWindowHandle>,
1237 cx: &mut Context<Self>,
1238 ) {
1239 let pending_completion_id = post_inc(&mut self.completion_count);
1240 let mut request_callback_parameters = if self.request_callback.is_some() {
1241 Some((request.clone(), Vec::new()))
1242 } else {
1243 None
1244 };
1245 let prompt_id = self.last_prompt_id.clone();
1246 let tool_use_metadata = ToolUseMetadata {
1247 model: model.clone(),
1248 thread_id: self.id.clone(),
1249 prompt_id: prompt_id.clone(),
1250 };
1251
1252 let task = cx.spawn(async move |thread, cx| {
1253 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1254 let initial_token_usage =
1255 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1256 let stream_completion = async {
1257 let (mut events, usage) = stream_completion_future.await?;
1258
1259 let mut stop_reason = StopReason::EndTurn;
1260 let mut current_token_usage = TokenUsage::default();
1261
1262 if let Some(usage) = usage {
1263 thread
1264 .update(cx, |_thread, cx| {
1265 cx.emit(ThreadEvent::UsageUpdated(usage));
1266 })
1267 .ok();
1268 }
1269
1270 let mut request_assistant_message_id = None;
1271
1272 while let Some(event) = events.next().await {
1273 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1274 response_events
1275 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1276 }
1277
1278 let event = event?;
1279
1280 thread.update(cx, |thread, cx| {
1281 match event {
1282 LanguageModelCompletionEvent::StartMessage { .. } => {
1283 request_assistant_message_id = Some(thread.insert_message(
1284 Role::Assistant,
1285 vec![MessageSegment::Text(String::new())],
1286 cx,
1287 ));
1288 }
1289 LanguageModelCompletionEvent::Stop(reason) => {
1290 stop_reason = reason;
1291 }
1292 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1293 thread.update_token_usage_at_last_message(token_usage);
1294 thread.cumulative_token_usage = thread.cumulative_token_usage
1295 + token_usage
1296 - current_token_usage;
1297 current_token_usage = token_usage;
1298 }
1299 LanguageModelCompletionEvent::Text(chunk) => {
1300 cx.emit(ThreadEvent::ReceivedTextChunk);
1301 if let Some(last_message) = thread.messages.last_mut() {
1302 if last_message.role == Role::Assistant
1303 && !thread.tool_use.has_tool_results(last_message.id)
1304 {
1305 last_message.push_text(&chunk);
1306 cx.emit(ThreadEvent::StreamedAssistantText(
1307 last_message.id,
1308 chunk,
1309 ));
1310 } else {
1311 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1312 // of a new Assistant response.
1313 //
1314 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1315 // will result in duplicating the text of the chunk in the rendered Markdown.
1316 request_assistant_message_id = Some(thread.insert_message(
1317 Role::Assistant,
1318 vec![MessageSegment::Text(chunk.to_string())],
1319 cx,
1320 ));
1321 };
1322 }
1323 }
1324 LanguageModelCompletionEvent::Thinking {
1325 text: chunk,
1326 signature,
1327 } => {
1328 if let Some(last_message) = thread.messages.last_mut() {
1329 if last_message.role == Role::Assistant
1330 && !thread.tool_use.has_tool_results(last_message.id)
1331 {
1332 last_message.push_thinking(&chunk, signature);
1333 cx.emit(ThreadEvent::StreamedAssistantThinking(
1334 last_message.id,
1335 chunk,
1336 ));
1337 } else {
1338 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1339 // of a new Assistant response.
1340 //
1341 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1342 // will result in duplicating the text of the chunk in the rendered Markdown.
1343 request_assistant_message_id = Some(thread.insert_message(
1344 Role::Assistant,
1345 vec![MessageSegment::Thinking {
1346 text: chunk.to_string(),
1347 signature,
1348 }],
1349 cx,
1350 ));
1351 };
1352 }
1353 }
1354 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1355 let last_assistant_message_id = request_assistant_message_id
1356 .unwrap_or_else(|| {
1357 let new_assistant_message_id =
1358 thread.insert_message(Role::Assistant, vec![], cx);
1359 request_assistant_message_id =
1360 Some(new_assistant_message_id);
1361 new_assistant_message_id
1362 });
1363
1364 let tool_use_id = tool_use.id.clone();
1365 let streamed_input = if tool_use.is_input_complete {
1366 None
1367 } else {
1368 Some((&tool_use.input).clone())
1369 };
1370
1371 let ui_text = thread.tool_use.request_tool_use(
1372 last_assistant_message_id,
1373 tool_use,
1374 tool_use_metadata.clone(),
1375 cx,
1376 );
1377
1378 if let Some(input) = streamed_input {
1379 cx.emit(ThreadEvent::StreamedToolUse {
1380 tool_use_id,
1381 ui_text,
1382 input,
1383 });
1384 }
1385 }
1386 }
1387
1388 thread.touch_updated_at();
1389 cx.emit(ThreadEvent::StreamedCompletion);
1390 cx.notify();
1391
1392 thread.auto_capture_telemetry(cx);
1393 })?;
1394
1395 smol::future::yield_now().await;
1396 }
1397
1398 thread.update(cx, |thread, cx| {
1399 thread
1400 .pending_completions
1401 .retain(|completion| completion.id != pending_completion_id);
1402
1403 // If there is a response without tool use, summarize the message. Otherwise,
1404 // allow two tool uses before summarizing.
1405 if thread.summary.is_none()
1406 && thread.messages.len() >= 2
1407 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1408 {
1409 thread.summarize(cx);
1410 }
1411 })?;
1412
1413 anyhow::Ok(stop_reason)
1414 };
1415
1416 let result = stream_completion.await;
1417
1418 thread
1419 .update(cx, |thread, cx| {
1420 thread.finalize_pending_checkpoint(cx);
1421 match result.as_ref() {
1422 Ok(stop_reason) => match stop_reason {
1423 StopReason::ToolUse => {
1424 let tool_uses = thread.use_pending_tools(window, cx);
1425 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1426 }
1427 StopReason::EndTurn => {}
1428 StopReason::MaxTokens => {}
1429 },
1430 Err(error) => {
1431 if error.is::<PaymentRequiredError>() {
1432 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1433 } else if error.is::<MaxMonthlySpendReachedError>() {
1434 cx.emit(ThreadEvent::ShowError(
1435 ThreadError::MaxMonthlySpendReached,
1436 ));
1437 } else if let Some(error) =
1438 error.downcast_ref::<ModelRequestLimitReachedError>()
1439 {
1440 cx.emit(ThreadEvent::ShowError(
1441 ThreadError::ModelRequestLimitReached { plan: error.plan },
1442 ));
1443 } else if let Some(known_error) =
1444 error.downcast_ref::<LanguageModelKnownError>()
1445 {
1446 match known_error {
1447 LanguageModelKnownError::ContextWindowLimitExceeded {
1448 tokens,
1449 } => {
1450 thread.exceeded_window_error = Some(ExceededWindowError {
1451 model_id: model.id(),
1452 token_count: *tokens,
1453 });
1454 cx.notify();
1455 }
1456 }
1457 } else {
1458 let error_message = error
1459 .chain()
1460 .map(|err| err.to_string())
1461 .collect::<Vec<_>>()
1462 .join("\n");
1463 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1464 header: "Error interacting with language model".into(),
1465 message: SharedString::from(error_message.clone()),
1466 }));
1467 }
1468
1469 thread.cancel_last_completion(window, cx);
1470 }
1471 }
1472 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1473
1474 if let Some((request_callback, (request, response_events))) = thread
1475 .request_callback
1476 .as_mut()
1477 .zip(request_callback_parameters.as_ref())
1478 {
1479 request_callback(request, response_events);
1480 }
1481
1482 thread.auto_capture_telemetry(cx);
1483
1484 if let Ok(initial_usage) = initial_token_usage {
1485 let usage = thread.cumulative_token_usage - initial_usage;
1486
1487 telemetry::event!(
1488 "Assistant Thread Completion",
1489 thread_id = thread.id().to_string(),
1490 prompt_id = prompt_id,
1491 model = model.telemetry_id(),
1492 model_provider = model.provider_id().to_string(),
1493 input_tokens = usage.input_tokens,
1494 output_tokens = usage.output_tokens,
1495 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1496 cache_read_input_tokens = usage.cache_read_input_tokens,
1497 );
1498 }
1499 })
1500 .ok();
1501 });
1502
1503 self.pending_completions.push(PendingCompletion {
1504 id: pending_completion_id,
1505 _task: task,
1506 });
1507 }
1508
1509 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1510 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1511 return;
1512 };
1513
1514 if !model.provider.is_authenticated(cx) {
1515 return;
1516 }
1517
1518 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1519 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1520 If the conversation is about a specific subject, include it in the title. \
1521 Be descriptive. DO NOT speak in the first person.";
1522
1523 let request = self.to_summarize_request(added_user_message.into());
1524
1525 self.pending_summary = cx.spawn(async move |this, cx| {
1526 async move {
1527 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1528 let (mut messages, usage) = stream.await?;
1529
1530 if let Some(usage) = usage {
1531 this.update(cx, |_thread, cx| {
1532 cx.emit(ThreadEvent::UsageUpdated(usage));
1533 })
1534 .ok();
1535 }
1536
1537 let mut new_summary = String::new();
1538 while let Some(message) = messages.stream.next().await {
1539 let text = message?;
1540 let mut lines = text.lines();
1541 new_summary.extend(lines.next());
1542
1543 // Stop if the LLM generated multiple lines.
1544 if lines.next().is_some() {
1545 break;
1546 }
1547 }
1548
1549 this.update(cx, |this, cx| {
1550 if !new_summary.is_empty() {
1551 this.summary = Some(new_summary.into());
1552 }
1553
1554 cx.emit(ThreadEvent::SummaryGenerated);
1555 })?;
1556
1557 anyhow::Ok(())
1558 }
1559 .log_err()
1560 .await
1561 });
1562 }
1563
1564 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1565 let last_message_id = self.messages.last().map(|message| message.id)?;
1566
1567 match &self.detailed_summary_state {
1568 DetailedSummaryState::Generating { message_id, .. }
1569 | DetailedSummaryState::Generated { message_id, .. }
1570 if *message_id == last_message_id =>
1571 {
1572 // Already up-to-date
1573 return None;
1574 }
1575 _ => {}
1576 }
1577
1578 let ConfiguredModel { model, provider } =
1579 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1580
1581 if !provider.is_authenticated(cx) {
1582 return None;
1583 }
1584
1585 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1586 1. A brief overview of what was discussed\n\
1587 2. Key facts or information discovered\n\
1588 3. Outcomes or conclusions reached\n\
1589 4. Any action items or next steps if any\n\
1590 Format it in Markdown with headings and bullet points.";
1591
1592 let request = self.to_summarize_request(added_user_message.into());
1593
1594 let task = cx.spawn(async move |thread, cx| {
1595 let stream = model.stream_completion_text(request, &cx);
1596 let Some(mut messages) = stream.await.log_err() else {
1597 thread
1598 .update(cx, |this, _cx| {
1599 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1600 })
1601 .log_err();
1602
1603 return;
1604 };
1605
1606 let mut new_detailed_summary = String::new();
1607
1608 while let Some(chunk) = messages.stream.next().await {
1609 if let Some(chunk) = chunk.log_err() {
1610 new_detailed_summary.push_str(&chunk);
1611 }
1612 }
1613
1614 thread
1615 .update(cx, |this, _cx| {
1616 this.detailed_summary_state = DetailedSummaryState::Generated {
1617 text: new_detailed_summary.into(),
1618 message_id: last_message_id,
1619 };
1620 })
1621 .log_err();
1622 });
1623
1624 self.detailed_summary_state = DetailedSummaryState::Generating {
1625 message_id: last_message_id,
1626 };
1627
1628 Some(task)
1629 }
1630
1631 pub fn is_generating_detailed_summary(&self) -> bool {
1632 matches!(
1633 self.detailed_summary_state,
1634 DetailedSummaryState::Generating { .. }
1635 )
1636 }
1637
1638 pub fn use_pending_tools(
1639 &mut self,
1640 window: Option<AnyWindowHandle>,
1641 cx: &mut Context<Self>,
1642 ) -> Vec<PendingToolUse> {
1643 self.auto_capture_telemetry(cx);
1644 let request = self.to_completion_request(cx);
1645 let messages = Arc::new(request.messages);
1646 let pending_tool_uses = self
1647 .tool_use
1648 .pending_tool_uses()
1649 .into_iter()
1650 .filter(|tool_use| tool_use.status.is_idle())
1651 .cloned()
1652 .collect::<Vec<_>>();
1653
1654 for tool_use in pending_tool_uses.iter() {
1655 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1656 if tool.needs_confirmation(&tool_use.input, cx)
1657 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1658 {
1659 self.tool_use.confirm_tool_use(
1660 tool_use.id.clone(),
1661 tool_use.ui_text.clone(),
1662 tool_use.input.clone(),
1663 messages.clone(),
1664 tool,
1665 );
1666 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1667 } else {
1668 self.run_tool(
1669 tool_use.id.clone(),
1670 tool_use.ui_text.clone(),
1671 tool_use.input.clone(),
1672 &messages,
1673 tool,
1674 window,
1675 cx,
1676 );
1677 }
1678 }
1679 }
1680
1681 pending_tool_uses
1682 }
1683
1684 pub fn run_tool(
1685 &mut self,
1686 tool_use_id: LanguageModelToolUseId,
1687 ui_text: impl Into<SharedString>,
1688 input: serde_json::Value,
1689 messages: &[LanguageModelRequestMessage],
1690 tool: Arc<dyn Tool>,
1691 window: Option<AnyWindowHandle>,
1692 cx: &mut Context<Thread>,
1693 ) {
1694 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1695 self.tool_use
1696 .run_pending_tool(tool_use_id, ui_text.into(), task);
1697 }
1698
1699 fn spawn_tool_use(
1700 &mut self,
1701 tool_use_id: LanguageModelToolUseId,
1702 messages: &[LanguageModelRequestMessage],
1703 input: serde_json::Value,
1704 tool: Arc<dyn Tool>,
1705 window: Option<AnyWindowHandle>,
1706 cx: &mut Context<Thread>,
1707 ) -> Task<()> {
1708 let tool_name: Arc<str> = tool.name().into();
1709
1710 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1711 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1712 } else {
1713 tool.run(
1714 input,
1715 messages,
1716 self.project.clone(),
1717 self.action_log.clone(),
1718 window,
1719 cx,
1720 )
1721 };
1722
1723 // Store the card separately if it exists
1724 if let Some(card) = tool_result.card.clone() {
1725 self.tool_use
1726 .insert_tool_result_card(tool_use_id.clone(), card);
1727 }
1728
1729 cx.spawn({
1730 async move |thread: WeakEntity<Thread>, cx| {
1731 let output = tool_result.output.await;
1732
1733 thread
1734 .update(cx, |thread, cx| {
1735 let pending_tool_use = thread.tool_use.insert_tool_output(
1736 tool_use_id.clone(),
1737 tool_name,
1738 output,
1739 cx,
1740 );
1741 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1742 })
1743 .ok();
1744 }
1745 })
1746 }
1747
1748 fn tool_finished(
1749 &mut self,
1750 tool_use_id: LanguageModelToolUseId,
1751 pending_tool_use: Option<PendingToolUse>,
1752 canceled: bool,
1753 window: Option<AnyWindowHandle>,
1754 cx: &mut Context<Self>,
1755 ) {
1756 if self.all_tools_finished() {
1757 let model_registry = LanguageModelRegistry::read_global(cx);
1758 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1759 if !canceled {
1760 self.send_to_model(model, window, cx);
1761 }
1762 self.auto_capture_telemetry(cx);
1763 }
1764 }
1765
1766 cx.emit(ThreadEvent::ToolFinished {
1767 tool_use_id,
1768 pending_tool_use,
1769 });
1770 }
1771
1772 /// Cancels the last pending completion, if there are any pending.
1773 ///
1774 /// Returns whether a completion was canceled.
1775 pub fn cancel_last_completion(
1776 &mut self,
1777 window: Option<AnyWindowHandle>,
1778 cx: &mut Context<Self>,
1779 ) -> bool {
1780 let mut canceled = self.pending_completions.pop().is_some();
1781
1782 for pending_tool_use in self.tool_use.cancel_pending() {
1783 canceled = true;
1784 self.tool_finished(
1785 pending_tool_use.id.clone(),
1786 Some(pending_tool_use),
1787 true,
1788 window,
1789 cx,
1790 );
1791 }
1792
1793 self.finalize_pending_checkpoint(cx);
1794 canceled
1795 }
1796
1797 pub fn feedback(&self) -> Option<ThreadFeedback> {
1798 self.feedback
1799 }
1800
1801 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1802 self.message_feedback.get(&message_id).copied()
1803 }
1804
1805 pub fn report_message_feedback(
1806 &mut self,
1807 message_id: MessageId,
1808 feedback: ThreadFeedback,
1809 cx: &mut Context<Self>,
1810 ) -> Task<Result<()>> {
1811 if self.message_feedback.get(&message_id) == Some(&feedback) {
1812 return Task::ready(Ok(()));
1813 }
1814
1815 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1816 let serialized_thread = self.serialize(cx);
1817 let thread_id = self.id().clone();
1818 let client = self.project.read(cx).client();
1819
1820 let enabled_tool_names: Vec<String> = self
1821 .tools()
1822 .read(cx)
1823 .enabled_tools(cx)
1824 .iter()
1825 .map(|tool| tool.name().to_string())
1826 .collect();
1827
1828 self.message_feedback.insert(message_id, feedback);
1829
1830 cx.notify();
1831
1832 let message_content = self
1833 .message(message_id)
1834 .map(|msg| msg.to_string())
1835 .unwrap_or_default();
1836
1837 cx.background_spawn(async move {
1838 let final_project_snapshot = final_project_snapshot.await;
1839 let serialized_thread = serialized_thread.await?;
1840 let thread_data =
1841 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1842
1843 let rating = match feedback {
1844 ThreadFeedback::Positive => "positive",
1845 ThreadFeedback::Negative => "negative",
1846 };
1847 telemetry::event!(
1848 "Assistant Thread Rated",
1849 rating,
1850 thread_id,
1851 enabled_tool_names,
1852 message_id = message_id.0,
1853 message_content,
1854 thread_data,
1855 final_project_snapshot
1856 );
1857 client.telemetry().flush_events().await;
1858
1859 Ok(())
1860 })
1861 }
1862
1863 pub fn report_feedback(
1864 &mut self,
1865 feedback: ThreadFeedback,
1866 cx: &mut Context<Self>,
1867 ) -> Task<Result<()>> {
1868 let last_assistant_message_id = self
1869 .messages
1870 .iter()
1871 .rev()
1872 .find(|msg| msg.role == Role::Assistant)
1873 .map(|msg| msg.id);
1874
1875 if let Some(message_id) = last_assistant_message_id {
1876 self.report_message_feedback(message_id, feedback, cx)
1877 } else {
1878 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1879 let serialized_thread = self.serialize(cx);
1880 let thread_id = self.id().clone();
1881 let client = self.project.read(cx).client();
1882 self.feedback = Some(feedback);
1883 cx.notify();
1884
1885 cx.background_spawn(async move {
1886 let final_project_snapshot = final_project_snapshot.await;
1887 let serialized_thread = serialized_thread.await?;
1888 let thread_data = serde_json::to_value(serialized_thread)
1889 .unwrap_or_else(|_| serde_json::Value::Null);
1890
1891 let rating = match feedback {
1892 ThreadFeedback::Positive => "positive",
1893 ThreadFeedback::Negative => "negative",
1894 };
1895 telemetry::event!(
1896 "Assistant Thread Rated",
1897 rating,
1898 thread_id,
1899 thread_data,
1900 final_project_snapshot
1901 );
1902 client.telemetry().flush_events().await;
1903
1904 Ok(())
1905 })
1906 }
1907 }
1908
1909 /// Create a snapshot of the current project state including git information and unsaved buffers.
1910 fn project_snapshot(
1911 project: Entity<Project>,
1912 cx: &mut Context<Self>,
1913 ) -> Task<Arc<ProjectSnapshot>> {
1914 let git_store = project.read(cx).git_store().clone();
1915 let worktree_snapshots: Vec<_> = project
1916 .read(cx)
1917 .visible_worktrees(cx)
1918 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1919 .collect();
1920
1921 cx.spawn(async move |_, cx| {
1922 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1923
1924 let mut unsaved_buffers = Vec::new();
1925 cx.update(|app_cx| {
1926 let buffer_store = project.read(app_cx).buffer_store();
1927 for buffer_handle in buffer_store.read(app_cx).buffers() {
1928 let buffer = buffer_handle.read(app_cx);
1929 if buffer.is_dirty() {
1930 if let Some(file) = buffer.file() {
1931 let path = file.path().to_string_lossy().to_string();
1932 unsaved_buffers.push(path);
1933 }
1934 }
1935 }
1936 })
1937 .ok();
1938
1939 Arc::new(ProjectSnapshot {
1940 worktree_snapshots,
1941 unsaved_buffer_paths: unsaved_buffers,
1942 timestamp: Utc::now(),
1943 })
1944 })
1945 }
1946
1947 fn worktree_snapshot(
1948 worktree: Entity<project::Worktree>,
1949 git_store: Entity<GitStore>,
1950 cx: &App,
1951 ) -> Task<WorktreeSnapshot> {
1952 cx.spawn(async move |cx| {
1953 // Get worktree path and snapshot
1954 let worktree_info = cx.update(|app_cx| {
1955 let worktree = worktree.read(app_cx);
1956 let path = worktree.abs_path().to_string_lossy().to_string();
1957 let snapshot = worktree.snapshot();
1958 (path, snapshot)
1959 });
1960
1961 let Ok((worktree_path, _snapshot)) = worktree_info else {
1962 return WorktreeSnapshot {
1963 worktree_path: String::new(),
1964 git_state: None,
1965 };
1966 };
1967
1968 let git_state = git_store
1969 .update(cx, |git_store, cx| {
1970 git_store
1971 .repositories()
1972 .values()
1973 .find(|repo| {
1974 repo.read(cx)
1975 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1976 .is_some()
1977 })
1978 .cloned()
1979 })
1980 .ok()
1981 .flatten()
1982 .map(|repo| {
1983 repo.update(cx, |repo, _| {
1984 let current_branch =
1985 repo.branch.as_ref().map(|branch| branch.name.to_string());
1986 repo.send_job(None, |state, _| async move {
1987 let RepositoryState::Local { backend, .. } = state else {
1988 return GitState {
1989 remote_url: None,
1990 head_sha: None,
1991 current_branch,
1992 diff: None,
1993 };
1994 };
1995
1996 let remote_url = backend.remote_url("origin");
1997 let head_sha = backend.head_sha();
1998 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1999
2000 GitState {
2001 remote_url,
2002 head_sha,
2003 current_branch,
2004 diff,
2005 }
2006 })
2007 })
2008 });
2009
            let git_state = match git_state.and_then(|git_state| git_state.ok()) {
                Some(git_state) => git_state.await.ok(),
                None => None,
            };
2017
2018 WorktreeSnapshot {
2019 worktree_path,
2020 git_state,
2021 }
2022 })
2023 }
2024
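    /// Renders the thread as Markdown, including each message's role, attached context,
    /// text and thinking segments, tool uses, and tool results.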
2025 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2026 let mut markdown = Vec::new();
2027
2028 if let Some(summary) = self.summary() {
2029 writeln!(markdown, "# {summary}\n")?;
        }
2031
2032 for message in self.messages() {
2033 writeln!(
2034 markdown,
2035 "## {role}\n",
2036 role = match message.role {
2037 Role::User => "User",
2038 Role::Assistant => "Assistant",
2039 Role::System => "System",
2040 }
2041 )?;
2042
2043 if !message.context.is_empty() {
2044 writeln!(markdown, "{}", message.context)?;
2045 }
2046
2047 for segment in &message.segments {
2048 match segment {
2049 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2050 MessageSegment::Thinking { text, .. } => {
2051 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2052 }
2053 MessageSegment::RedactedThinking(_) => {}
2054 }
2055 }
2056
2057 for tool_use in self.tool_uses_for_message(message.id, cx) {
2058 writeln!(
2059 markdown,
2060 "**Use Tool: {} ({})**",
2061 tool_use.name, tool_use.id
2062 )?;
2063 writeln!(markdown, "```json")?;
2064 writeln!(
2065 markdown,
2066 "{}",
2067 serde_json::to_string_pretty(&tool_use.input)?
2068 )?;
2069 writeln!(markdown, "```")?;
2070 }
2071
2072 for tool_result in self.tool_results_for_message(message.id) {
2073 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2074 if tool_result.is_error {
2075 write!(markdown, " (Error)")?;
2076 }
2077
2078 writeln!(markdown, "**\n")?;
2079 writeln!(markdown, "{}", tool_result.content)?;
2080 }
2081 }
2082
2083 Ok(String::from_utf8_lossy(&markdown).to_string())
2084 }
2085
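    /// Marks the agent's edits within the given buffer range as kept.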
2086 pub fn keep_edits_in_range(
2087 &mut self,
2088 buffer: Entity<language::Buffer>,
2089 buffer_range: Range<language::Anchor>,
2090 cx: &mut Context<Self>,
2091 ) {
2092 self.action_log.update(cx, |action_log, cx| {
2093 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2094 });
2095 }
2096
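    /// Marks all of the agent's pending edits as kept.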
2097 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2098 self.action_log
2099 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2100 }
2101
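    /// Rejects the agent's edits within the given buffer ranges.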
2102 pub fn reject_edits_in_ranges(
2103 &mut self,
2104 buffer: Entity<language::Buffer>,
2105 buffer_ranges: Vec<Range<language::Anchor>>,
2106 cx: &mut Context<Self>,
2107 ) -> Task<Result<()>> {
2108 self.action_log.update(cx, |action_log, cx| {
2109 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2110 })
2111 }
2112
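    /// Returns the action log used to track the agent's edits to buffers.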
2113 pub fn action_log(&self) -> &Entity<ActionLog> {
2114 &self.action_log
2115 }
2116
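    /// Returns the project this thread operates on.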
2117 pub fn project(&self) -> &Entity<Project> {
2118 &self.project
2119 }
2120
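    /// Captures a serialized copy of the thread as a telemetry event, gated behind the
    /// `ThreadAutoCapture` feature flag and throttled to at most once every 10 seconds.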
2121 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2122 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2123 return;
2124 }
2125
2126 let now = Instant::now();
2127 if let Some(last) = self.last_auto_capture_at {
2128 if now.duration_since(last).as_secs() < 10 {
2129 return;
2130 }
2131 }
2132
2133 self.last_auto_capture_at = Some(now);
2134
2135 let thread_id = self.id().clone();
2136 let github_login = self
2137 .project
2138 .read(cx)
2139 .user_store()
2140 .read(cx)
2141 .current_user()
2142 .map(|user| user.github_login.clone());
2143 let client = self.project.read(cx).client().clone();
2144 let serialize_task = self.serialize(cx);
2145
2146 cx.background_executor()
2147 .spawn(async move {
2148 if let Ok(serialized_thread) = serialize_task.await {
2149 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2150 telemetry::event!(
2151 "Agent Thread Auto-Captured",
2152 thread_id = thread_id.to_string(),
2153 thread_data = thread_data,
2154 auto_capture_reason = "tracked_user",
2155 github_login = github_login
2156 );
2157
2158 client.telemetry().flush_events().await;
2159 }
2160 }
2161 })
2162 .detach();
2163 }
2164
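    /// Returns the token usage accumulated across all requests made by this thread.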
2165 pub fn cumulative_token_usage(&self) -> TokenUsage {
2166 self.cumulative_token_usage
2167 }
2168
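    /// Returns the token usage recorded immediately before the given message, along with
    /// the default model's maximum token count.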
2169 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2170 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2171 return TotalTokenUsage::default();
2172 };
2173
2174 let max = model.model.max_token_count();
2175
2176 let index = self
2177 .messages
2178 .iter()
2179 .position(|msg| msg.id == message_id)
2180 .unwrap_or(0);
2181
2182 if index == 0 {
2183 return TotalTokenUsage { total: 0, max };
2184 }
2185
        let token_usage = self
            .request_token_usage
            .get(index - 1)
            .cloned()
            .unwrap_or_default();
2191
2192 TotalTokenUsage {
2193 total: token_usage.total_tokens() as usize,
2194 max,
2195 }
2196 }
2197
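    /// Returns the thread's total token usage against the default model's maximum token count.
    ///
    /// If the current model previously exceeded its context window, the token count recorded
    /// by that error is reported instead.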
2198 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2199 let model_registry = LanguageModelRegistry::read_global(cx);
2200 let Some(model) = model_registry.default_model() else {
2201 return TotalTokenUsage::default();
2202 };
2203
2204 let max = model.model.max_token_count();
2205
2206 if let Some(exceeded_error) = &self.exceeded_window_error {
2207 if model.model.id() == exceeded_error.model_id {
2208 return TotalTokenUsage {
2209 total: exceeded_error.token_count,
2210 max,
2211 };
2212 }
2213 }
2214
2215 let total = self
2216 .token_usage_at_last_message()
2217 .unwrap_or_default()
2218 .total_tokens() as usize;
2219
2220 TotalTokenUsage { total, max }
2221 }
2222
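    /// Returns the token usage recorded at the last message, falling back to the most
    /// recently recorded usage.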
2223 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2224 self.request_token_usage
2225 .get(self.messages.len().saturating_sub(1))
2226 .or_else(|| self.request_token_usage.last())
2227 .cloned()
2228 }
2229
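    /// Records the given token usage for the last message, resizing the per-request usage
    /// history to match the current number of messages.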
2230 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2231 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2232 self.request_token_usage
2233 .resize(self.messages.len(), placeholder);
2234
2235 if let Some(last) = self.request_token_usage.last_mut() {
2236 *last = token_usage;
2237 }
2238 }
2239
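    /// Records that the user denied a tool use, reporting a "permission denied" error as
    /// the tool's output.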
2240 pub fn deny_tool_use(
2241 &mut self,
2242 tool_use_id: LanguageModelToolUseId,
2243 tool_name: Arc<str>,
2244 window: Option<AnyWindowHandle>,
2245 cx: &mut Context<Self>,
2246 ) {
2247 let err = Err(anyhow::anyhow!(
2248 "Permission to run tool action denied by user"
2249 ));
2250
2251 self.tool_use
2252 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2253 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2254 }
2255}
2256
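/// An error surfaced to the user while running a [`Thread`].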
2257#[derive(Debug, Clone, Error)]
2258pub enum ThreadError {
2259 #[error("Payment required")]
2260 PaymentRequired,
2261 #[error("Max monthly spend reached")]
2262 MaxMonthlySpendReached,
2263 #[error("Model request limit reached")]
2264 ModelRequestLimitReached { plan: Plan },
2265 #[error("Message {header}: {message}")]
2266 Message {
2267 header: SharedString,
2268 message: SharedString,
2269 },
2270}
2271
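/// An event emitted by a [`Thread`].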
2272#[derive(Debug, Clone)]
2273pub enum ThreadEvent {
2274 ShowError(ThreadError),
2275 UsageUpdated(RequestUsage),
2276 StreamedCompletion,
2277 ReceivedTextChunk,
2278 StreamedAssistantText(MessageId, String),
2279 StreamedAssistantThinking(MessageId, String),
2280 StreamedToolUse {
2281 tool_use_id: LanguageModelToolUseId,
2282 ui_text: Arc<str>,
2283 input: serde_json::Value,
2284 },
2285 Stopped(Result<StopReason, Arc<anyhow::Error>>),
2286 MessageAdded(MessageId),
2287 MessageEdited(MessageId),
2288 MessageDeleted(MessageId),
2289 SummaryGenerated,
2290 SummaryChanged,
2291 UsePendingTools {
2292 tool_uses: Vec<PendingToolUse>,
2293 },
2294 ToolFinished {
2295 #[allow(unused)]
2296 tool_use_id: LanguageModelToolUseId,
2297 /// The pending tool use that corresponds to this tool.
2298 pending_tool_use: Option<PendingToolUse>,
2299 },
2300 CheckpointChanged,
2301 ToolConfirmationNeeded,
2302}
2303
2304impl EventEmitter<ThreadEvent> for Thread {}
2305
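/// A completion request that is still streaming. Dropping it drops the background task
/// driving the completion.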
2306struct PendingCompletion {
2307 id: usize,
2308 _task: Task<()>,
2309}
2310
2311#[cfg(test)]
2312mod tests {
2313 use super::*;
2314 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2315 use assistant_settings::AssistantSettings;
2316 use context_server::ContextServerSettings;
2317 use editor::EditorSettings;
2318 use gpui::TestAppContext;
2319 use project::{FakeFs, Project};
2320 use prompt_store::PromptBuilder;
2321 use serde_json::json;
2322 use settings::{Settings, SettingsStore};
2323 use std::sync::Arc;
2324 use theme::ThemeSettings;
2325 use util::path;
2326 use workspace::Workspace;
2327
2328 #[gpui::test]
2329 async fn test_message_with_context(cx: &mut TestAppContext) {
2330 init_test_settings(cx);
2331
2332 let project = create_test_project(
2333 cx,
2334 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2335 )
2336 .await;
2337
2338 let (_workspace, _thread_store, thread, context_store) =
2339 setup_test_environment(cx, project.clone()).await;
2340
2341 add_file_to_context(&project, &context_store, "test/code.rs", cx)
2342 .await
2343 .unwrap();
2344
2345 let context =
2346 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2347
2348 // Insert user message with context
2349 let message_id = thread.update(cx, |thread, cx| {
2350 thread.insert_user_message("Please explain this code", vec![context], None, cx)
2351 });
2352
2353 // Check content and context in message object
2354 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2355
2356 // Use different path format strings based on platform for the test
2357 #[cfg(windows)]
2358 let path_part = r"test\code.rs";
2359 #[cfg(not(windows))]
2360 let path_part = "test/code.rs";
2361
2362 let expected_context = format!(
2363 r#"
2364<context>
2365The following items were attached by the user. You don't need to use other tools to read them.
2366
2367<files>
2368```rs {path_part}
2369fn main() {{
2370 println!("Hello, world!");
2371}}
2372```
2373</files>
2374</context>
2375"#
2376 );
2377
2378 assert_eq!(message.role, Role::User);
2379 assert_eq!(message.segments.len(), 1);
2380 assert_eq!(
2381 message.segments[0],
2382 MessageSegment::Text("Please explain this code".to_string())
2383 );
2384 assert_eq!(message.context, expected_context);
2385
2386 // Check message in request
2387 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2388
2389 assert_eq!(request.messages.len(), 2);
2390 let expected_full_message = format!("{}Please explain this code", expected_context);
2391 assert_eq!(request.messages[1].string_contents(), expected_full_message);
2392 }
2393
2394 #[gpui::test]
2395 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2396 init_test_settings(cx);
2397
2398 let project = create_test_project(
2399 cx,
2400 json!({
2401 "file1.rs": "fn function1() {}\n",
2402 "file2.rs": "fn function2() {}\n",
2403 "file3.rs": "fn function3() {}\n",
2404 }),
2405 )
2406 .await;
2407
2408 let (_, _thread_store, thread, context_store) =
2409 setup_test_environment(cx, project.clone()).await;
2410
2411 // Open files individually
2412 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2413 .await
2414 .unwrap();
2415 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2416 .await
2417 .unwrap();
2418 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2419 .await
2420 .unwrap();
2421
2422 // Get the context objects
2423 let contexts = context_store.update(cx, |store, _| store.context().clone());
2424 assert_eq!(contexts.len(), 3);
2425
2426 // First message with context 1
2427 let message1_id = thread.update(cx, |thread, cx| {
2428 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2429 });
2430
2431 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2432 let message2_id = thread.update(cx, |thread, cx| {
2433 thread.insert_user_message(
2434 "Message 2",
2435 vec![contexts[0].clone(), contexts[1].clone()],
2436 None,
2437 cx,
2438 )
2439 });
2440
2441 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2442 let message3_id = thread.update(cx, |thread, cx| {
2443 thread.insert_user_message(
2444 "Message 3",
2445 vec![
2446 contexts[0].clone(),
2447 contexts[1].clone(),
2448 contexts[2].clone(),
2449 ],
2450 None,
2451 cx,
2452 )
2453 });
2454
2455 // Check what contexts are included in each message
2456 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2457 (
2458 thread.message(message1_id).unwrap().clone(),
2459 thread.message(message2_id).unwrap().clone(),
2460 thread.message(message3_id).unwrap().clone(),
2461 )
2462 });
2463
2464 // First message should include context 1
2465 assert!(message1.context.contains("file1.rs"));
2466
2467 // Second message should include only context 2 (not 1)
2468 assert!(!message2.context.contains("file1.rs"));
2469 assert!(message2.context.contains("file2.rs"));
2470
2471 // Third message should include only context 3 (not 1 or 2)
2472 assert!(!message3.context.contains("file1.rs"));
2473 assert!(!message3.context.contains("file2.rs"));
2474 assert!(message3.context.contains("file3.rs"));
2475
2476 // Check entire request to make sure all contexts are properly included
2477 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2478
        // The request should contain the three user messages plus the thread's initial system message
2480 assert_eq!(request.messages.len(), 4);
2481
2482 // Check that the contexts are properly formatted in each message
2483 assert!(request.messages[1].string_contents().contains("file1.rs"));
2484 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2485 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2486
2487 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2488 assert!(request.messages[2].string_contents().contains("file2.rs"));
2489 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2490
2491 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2492 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2493 assert!(request.messages[3].string_contents().contains("file3.rs"));
2494 }
2495
2496 #[gpui::test]
2497 async fn test_message_without_files(cx: &mut TestAppContext) {
2498 init_test_settings(cx);
2499
2500 let project = create_test_project(
2501 cx,
2502 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2503 )
2504 .await;
2505
2506 let (_, _thread_store, thread, _context_store) =
2507 setup_test_environment(cx, project.clone()).await;
2508
2509 // Insert user message without any context (empty context vector)
2510 let message_id = thread.update(cx, |thread, cx| {
2511 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2512 });
2513
2514 // Check content and context in message object
2515 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2516
2517 // Context should be empty when no files are included
2518 assert_eq!(message.role, Role::User);
2519 assert_eq!(message.segments.len(), 1);
2520 assert_eq!(
2521 message.segments[0],
2522 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2523 );
2524 assert_eq!(message.context, "");
2525
2526 // Check message in request
2527 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2528
2529 assert_eq!(request.messages.len(), 2);
2530 assert_eq!(
2531 request.messages[1].string_contents(),
2532 "What is the best way to learn Rust?"
2533 );
2534
2535 // Add second message, also without context
2536 let message2_id = thread.update(cx, |thread, cx| {
2537 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2538 });
2539
2540 let message2 =
2541 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2542 assert_eq!(message2.context, "");
2543
2544 // Check that both messages appear in the request
2545 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2546
2547 assert_eq!(request.messages.len(), 3);
2548 assert_eq!(
2549 request.messages[1].string_contents(),
2550 "What is the best way to learn Rust?"
2551 );
2552 assert_eq!(
2553 request.messages[2].string_contents(),
2554 "Are there any good books?"
2555 );
2556 }
2557
2558 #[gpui::test]
2559 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2560 init_test_settings(cx);
2561
2562 let project = create_test_project(
2563 cx,
2564 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2565 )
2566 .await;
2567
2568 let (_workspace, _thread_store, thread, context_store) =
2569 setup_test_environment(cx, project.clone()).await;
2570
2571 // Open buffer and add it to context
2572 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2573 .await
2574 .unwrap();
2575
2576 let context =
2577 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2578
2579 // Insert user message with the buffer as context
2580 thread.update(cx, |thread, cx| {
2581 thread.insert_user_message("Explain this code", vec![context], None, cx)
2582 });
2583
2584 // Create a request and check that it doesn't have a stale buffer warning yet
2585 let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2586
2587 // Make sure we don't have a stale file warning yet
2588 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2589 msg.string_contents()
2590 .contains("These files changed since last read:")
2591 });
2592 assert!(
2593 !has_stale_warning,
2594 "Should not have stale buffer warning before buffer is modified"
2595 );
2596
2597 // Modify the buffer
2598 buffer.update(cx, |buffer, cx| {
            // Insert text near the start of the buffer so it differs from the version the model last read
2600 buffer.edit(
2601 [(1..1, "\n println!(\"Added a new line\");\n")],
2602 None,
2603 cx,
2604 );
2605 });
2606
2607 // Insert another user message without context
2608 thread.update(cx, |thread, cx| {
2609 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2610 });
2611
2612 // Create a new request and check for the stale buffer warning
2613 let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2614
2615 // We should have a stale file warning as the last message
2616 let last_message = new_request
2617 .messages
2618 .last()
2619 .expect("Request should have messages");
2620
2621 // The last message should be the stale buffer notification
2622 assert_eq!(last_message.role, Role::User);
2623
2624 // Check the exact content of the message
2625 let expected_content = "These files changed since last read:\n- code.rs\n";
2626 assert_eq!(
2627 last_message.string_contents(),
2628 expected_content,
2629 "Last message should be exactly the stale buffer notification"
2630 );
2631 }
2632
2633 fn init_test_settings(cx: &mut TestAppContext) {
2634 cx.update(|cx| {
2635 let settings_store = SettingsStore::test(cx);
2636 cx.set_global(settings_store);
2637 language::init(cx);
2638 Project::init_settings(cx);
2639 AssistantSettings::register(cx);
2640 prompt_store::init(cx);
2641 thread_store::init(cx);
2642 workspace::init_settings(cx);
2643 ThemeSettings::register(cx);
2644 ContextServerSettings::register(cx);
2645 EditorSettings::register(cx);
2646 });
2647 }
2648
2649 // Helper to create a test project with test files
2650 async fn create_test_project(
2651 cx: &mut TestAppContext,
2652 files: serde_json::Value,
2653 ) -> Entity<Project> {
2654 let fs = FakeFs::new(cx.executor());
2655 fs.insert_tree(path!("/test"), files).await;
2656 Project::test(fs, [path!("/test").as_ref()], cx).await
2657 }
2658
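    // Helper to create a workspace, thread store, thread, and context store for tests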
2659 async fn setup_test_environment(
2660 cx: &mut TestAppContext,
2661 project: Entity<Project>,
2662 ) -> (
2663 Entity<Workspace>,
2664 Entity<ThreadStore>,
2665 Entity<Thread>,
2666 Entity<ContextStore>,
2667 ) {
2668 let (workspace, cx) =
2669 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2670
2671 let thread_store = cx
2672 .update(|_, cx| {
2673 ThreadStore::load(
2674 project.clone(),
2675 cx.new(|_| ToolWorkingSet::default()),
2676 Arc::new(PromptBuilder::new(None).unwrap()),
2677 cx,
2678 )
2679 })
2680 .await
2681 .unwrap();
2682
2683 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2684 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2685
2686 (workspace, thread_store, thread, context_store)
2687 }
2688
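    // Helper to open a buffer at `path` and add it to the context store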
2689 async fn add_file_to_context(
2690 project: &Entity<Project>,
2691 context_store: &Entity<ContextStore>,
2692 path: &str,
2693 cx: &mut TestAppContext,
2694 ) -> Result<Entity<language::Buffer>> {
2695 let buffer_path = project
2696 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2697 .unwrap();
2698
2699 let buffer = project
2700 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2701 .await
2702 .unwrap();
2703
2704 context_store
2705 .update(cx, |store, cx| {
2706 store.add_file_from_buffer(buffer.clone(), cx)
2707 })
2708 .await?;
2709
2710 Ok(buffer)
2711 }
2712}