1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
18};
19use language_model::{
20 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
21 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
22 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
23 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
24 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
25 TokenUsage,
26};
27use project::Project;
28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
29use prompt_store::PromptBuilder;
30use proto::Plan;
31use schemars::JsonSchema;
32use serde::{Deserialize, Serialize};
33use settings::Settings;
34use thiserror::Error;
35use util::{ResultExt as _, TryFutureExt as _, post_inc};
36use uuid::Uuid;
37use zed_llm_client::CompletionMode;
38
39use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
40use crate::thread_store::{
41 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
42 SerializedToolUse, SharedProjectContext,
43};
44use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
45
/// A globally unique identifier for a [`Thread`], stored as a shared string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Creates a new random thread ID backed by a v4 UUID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}
56
57impl std::fmt::Display for ThreadId {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 write!(f, "{}", self.0)
60 }
61}
62
63impl From<&str> for ThreadId {
64 fn from(value: &str) -> Self {
65 Self(value.into())
66 }
67}
68
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a new random prompt ID backed by a v4 UUID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}
80
81impl std::fmt::Display for PromptId {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 write!(f, "{}", self.0)
84 }
85}
86
/// Identifies a message within a thread by its insertion order.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
95
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    // Unique, monotonically increasing ID within the thread.
    pub id: MessageId,
    // Who authored the message (user, assistant, or system).
    pub role: Role,
    // Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    // Context (files, symbols, ...) attached when the message was sent.
    pub loaded_context: LoadedContext,
}
104
impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed
    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
    ///
    /// NOTE(review): this is true only when *every* segment reports itself
    /// displayable via [`MessageSegment::should_display`] — confirm `all`
    /// (vs. `any`) is the intended aggregation for multi-segment messages.
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }

    /// Appends thinking text, merging into a trailing `Thinking` segment when
    /// one exists (replacing its signature if a new one is provided), and
    /// starting a fresh `Thinking` segment otherwise.
    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        if let Some(MessageSegment::Thinking {
            text: segment,
            signature: current_signature,
        }) = self.segments.last_mut()
        {
            // A later signature replaces the earlier one for the same segment.
            if let Some(signature) = signature {
                *current_signature = Some(signature);
            }
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Thinking {
                text: text.to_string(),
                signature,
            });
        }
    }

    /// Appends plain text, merging into a trailing `Text` segment when one
    /// exists, and starting a fresh `Text` segment otherwise.
    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

    /// Renders the message as plain text: the loaded context text first, then
    /// each segment in order. Thinking segments are wrapped in `<think>` tags;
    /// redacted thinking is omitted entirely.
    pub fn to_string(&self) -> String {
        let mut result = String::new();

        if !self.loaded_context.text.is_empty() {
            result.push_str(&self.loaded_context.text);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    result.push_str("<think>\n");
                    result.push_str(text);
                    result.push_str("\n</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        result
    }
}
160
/// One piece of a message's content: plain text, visible model "thinking",
/// or provider-redacted (opaque) thinking bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has meaningful content worth displaying.
    ///
    /// Text and thinking segments count only when non-empty (the model
    /// sometimes emits empty segments while running tools); redacted thinking
    /// is never displayable.
    ///
    /// Fix: the previous implementation returned `text.is_empty()`, which is
    /// inverted — it marked empty segments as displayable and non-empty ones
    /// as hidden.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
180
/// A point-in-time capture of the project's worktrees, recorded when a thread
/// starts so the project state can be associated with the conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers with unsaved edits at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // `None` when the worktree has no associated git repository.
    pub git_state: Option<GitState>,
}

/// Git metadata captured for a worktree. All fields are best-effort and may
/// be absent (e.g. no remote configured, detached HEAD, or no diff computed).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
201
/// Associates a git checkpoint with the message it was captured before, so
/// the project can later be restored to that point in the conversation.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread (thumbs up / thumbs down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
213
214pub enum LastRestoreCheckpoint {
215 Pending {
216 message_id: MessageId,
217 },
218 Error {
219 message_id: MessageId,
220 error: String,
221 },
222}
223
224impl LastRestoreCheckpoint {
225 pub fn message_id(&self) -> MessageId {
226 match self {
227 LastRestoreCheckpoint::Pending { message_id } => *message_id,
228 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
229 }
230 }
231}
232
/// Progress of generating a detailed (multi-sentence) thread summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been requested yet.
    #[default]
    NotGenerated,
    /// A summary covering messages up to `message_id` is being generated.
    Generating {
        message_id: MessageId,
    },
    /// A summary covering messages up to `message_id` is available.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
245
/// Running total of tokens consumed against a model's context-window maximum.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens used so far.
    pub total: usize,
    /// Maximum tokens allowed by the model's context window.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, nearing the limit, or exceeded.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable. A missing or
    /// unparsable value now falls back to the default of 0.8 instead of
    /// panicking (the previous `.parse().unwrap()` would crash on any
    /// malformed override).
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When `max` is 0 (e.g. `Default`), `total >= max` holds, so we report
        // `Exceeded` and never reach the 0/0 division below.
        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total; `max` unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Severity bucket for context-window consumption.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
286
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    // Last time the thread was mutated.
    updated_at: DateTime<Utc>,
    // Short title; `None` until first generated.
    summary: Option<SharedString>,
    // In-flight summary-generation task, if any.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    completion_mode: Option<CompletionMode>,
    // Messages ordered by ascending `MessageId` (IDs assigned on insert).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints keyed by the message they were captured before.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    // Monotonic counter identifying pending completions.
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured at the start of the current turn; finalized (kept
    // or discarded) when the turn ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Test/debug hook observing each request and its raw response events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Remaining automatic turns before requiring explicit user action.
    remaining_turns: u32,
}
321
/// Recorded when a request fails because the conversation no longer fits in
/// the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
329
330impl Thread {
    /// Creates an empty thread with a fresh ID, capturing an initial project
    /// snapshot in the background.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            completion_mode: None,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Kick off the snapshot eagerly; `Shared` lets multiple
                // consumers await the same result later.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
375
    /// Reconstructs a thread from its serialized form (e.g. reopening a saved
    /// conversation). Ephemeral state — pending completions, checkpoints,
    /// feedback, callbacks — starts out empty.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest serialized message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            completion_mode: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; the
                    // structured context list and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
451
    /// Installs a callback invoked with each completion request and the raw
    /// stream of response events, e.g. for capturing traffic in tests.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
459
    /// The stable identifier for this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID, called when the user submits a new prompt.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, if one exists yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// The shared project context used to build the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used until a summary is generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, or [`Self::DEFAULT_SUMMARY`] when none exists.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
493
494 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
495 let Some(current_summary) = &self.summary else {
496 // Don't allow setting summary until generated
497 return;
498 };
499
500 let mut new_summary = new_summary.into();
501
502 if new_summary.is_empty() {
503 new_summary = Self::DEFAULT_SUMMARY;
504 }
505
506 if current_summary != &new_summary {
507 self.summary = Some(new_summary);
508 cx.emit(ThreadEvent::SummaryChanged);
509 }
510 }
511
    /// The latest detailed summary when one has been generated, falling back
    /// to the full thread text otherwise.
    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.latest_detailed_summary()
            .unwrap_or_else(|| self.text().into())
    }

    /// The detailed summary text, only when generation has completed.
    fn latest_detailed_summary(&self) -> Option<SharedString> {
        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
            Some(text.clone())
        } else {
            None
        }
    }
524
    /// The completion mode override for this thread, if any.
    pub fn completion_mode(&self) -> Option<CompletionMode> {
        self.completion_mode
    }

    /// Sets (or clears) the completion mode override.
    pub fn set_completion_mode(&mut self, mode: Option<CompletionMode>) {
        self.completion_mode = mode;
    }
532
533 pub fn message(&self, id: MessageId) -> Option<&Message> {
534 let index = self
535 .messages
536 .binary_search_by(|message| message.id.cmp(&id))
537 .ok()?;
538
539 self.messages.get(index)
540 }
541
    /// Iterates over all messages in insertion (and therefore ID) order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is streaming or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a still-pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation to run.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint captured just before the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
575
    /// Restores the project's files to `checkpoint`'s git state and, on
    /// success, truncates the thread back to the checkpoint's message.
    /// Progress and failure are tracked in `last_restore_checkpoint` and
    /// announced via `CheckpointChanged` events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so the UI can show what failed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way the in-progress checkpoint is no longer relevant.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
611
    /// Once generation has fully finished, decides whether the checkpoint
    /// captured at the start of the turn is worth keeping: it is inserted
    /// only when the project changed during the turn, or when taking or
    /// comparing the final checkpoint failed (kept conservatively).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still streaming or running tools; try again later.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison counts as "not equal", so the
                // checkpoint is retained rather than silently dropped.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
650
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore, if pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
661
662 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
663 let Some(message_ix) = self
664 .messages
665 .iter()
666 .rposition(|message| message.id == message_id)
667 else {
668 return;
669 };
670 for deleted_message in self.messages.drain(message_ix..) {
671 self.checkpoints_by_message.remove(&deleted_message.id);
672 }
673 cx.notify();
674 }
675
    /// The context items attached to the message with the given ID (an empty
    /// iterator when no such message exists).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
683
684 pub fn is_turn_end(&self, ix: usize) -> bool {
685 if self.messages.is_empty() {
686 return false;
687 }
688
689 if !self.is_generating() && ix == self.messages.len() - 1 {
690 return true;
691 }
692
693 let Some(message) = self.messages.get(ix) else {
694 return false;
695 };
696
697 if message.role != Role::Assistant {
698 return false;
699 }
700
701 self.messages
702 .get(ix + 1)
703 .and_then(|message| {
704 self.message(message.id)
705 .map(|next_message| next_message.role == Role::User)
706 })
707 .unwrap_or(false)
708 }
709
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// All tool uses issued by the message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// All tool results attached to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result for a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a completed tool use.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool use's result, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
742
    /// Appends a user message along with its loaded context, registering any
    /// referenced buffers with the action log and stashing a pending
    /// checkpoint when one was captured for this turn.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            cx,
        );

        // The checkpoint is finalized (kept or discarded) once the turn ends.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
776
    /// Appends an assistant message with no attached context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
    }

    /// Appends a message to the thread, assigning it the next sequential ID,
    /// and emits [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
803
    /// Replaces the role and segments of an existing message, returning
    /// `false` when no message with the given ID exists.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes a message by ID, returning `false` when it does not exist.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
830
831 /// Returns the representation of this [`Thread`] in a textual form.
832 ///
833 /// This is the representation we use when attaching a thread as context to another thread.
834 pub fn text(&self) -> String {
835 let mut text = String::new();
836
837 for message in &self.messages {
838 text.push_str(match message.role {
839 language_model::Role::User => "User:",
840 language_model::Role::Assistant => "Assistant:",
841 language_model::Role::System => "System:",
842 });
843 text.push('\n');
844
845 for segment in &message.segments {
846 match segment {
847 MessageSegment::Text(content) => text.push_str(content),
848 MessageSegment::Thinking { text: content, .. } => {
849 text.push_str(&format!("<think>{}</think>", content))
850 }
851 MessageSegment::RedactedThinking(_) => {}
852 }
853 }
854 text.push('\n');
855 }
856
857 text
858 }
859
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the background snapshot before reading thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
923
    /// How many model turns may still be taken automatically.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets the number of model turns that may still be taken automatically.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
931
    /// Sends the current conversation to `model`, consuming one remaining
    /// turn. Tool definitions are attached only when the model supports
    /// tools, and the completion-mode override only when max mode is
    /// supported.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        // Out of automatic turns: wait for explicit user action.
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let mut request = self.to_completion_request(cx);
        request.mode = if model.supports_max_mode() {
            self.completion_mode
        } else {
            None
        };

        if model.supports_tools() {
            request.tools = self
                .tools()
                .read(cx)
                .enabled_tools(cx)
                .into_iter()
                .filter_map(|tool| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name: tool.name(),
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect();
        }

        self.stream_completion(request, model, window, cx);
    }
971
972 pub fn used_tools_since_last_user_message(&self) -> bool {
973 for message in self.messages.iter().rev() {
974 if self.tool_use.message_has_tool_results(message.id) {
975 return true;
976 } else if message.role == Role::User {
977 return false;
978 }
979 }
980
981 false
982 }
983
    /// Builds a [`LanguageModelRequest`] for this conversation: the system
    /// prompt derived from project context, each message (attached context,
    /// segments, tool uses and results), and a trailing stale-files notice.
    /// The final message is flagged cacheable for prompt caching.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // The project context should be ready before any completion is
            // requested; surface the problem instead of silently proceeding
            // without a system prompt.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context precedes the message body.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results go in a separate follow-up message.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1080
1081 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1082 let mut request = LanguageModelRequest {
1083 thread_id: None,
1084 prompt_id: None,
1085 mode: None,
1086 messages: vec![],
1087 tools: Vec::new(),
1088 stop: Vec::new(),
1089 temperature: None,
1090 };
1091
1092 for message in &self.messages {
1093 let mut request_message = LanguageModelRequestMessage {
1094 role: message.role,
1095 content: Vec::new(),
1096 cache: false,
1097 };
1098
1099 for segment in &message.segments {
1100 match segment {
1101 MessageSegment::Text(text) => request_message
1102 .content
1103 .push(MessageContent::Text(text.clone())),
1104 MessageSegment::Thinking { .. } => {}
1105 MessageSegment::RedactedThinking(_) => {}
1106 }
1107 }
1108
1109 if request_message.content.is_empty() {
1110 continue;
1111 }
1112
1113 request.messages.push(request_message);
1114 }
1115
1116 request.messages.push(LanguageModelRequestMessage {
1117 role: Role::User,
1118 content: vec![MessageContent::Text(added_user_message)],
1119 cache: false,
1120 });
1121
1122 request
1123 }
1124
1125 fn attached_tracked_files_state(
1126 &self,
1127 messages: &mut Vec<LanguageModelRequestMessage>,
1128 cx: &App,
1129 ) {
1130 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1131
1132 let mut stale_message = String::new();
1133
1134 let action_log = self.action_log.read(cx);
1135
1136 for stale_file in action_log.stale_buffers(cx) {
1137 let Some(file) = stale_file.read(cx).file() else {
1138 continue;
1139 };
1140
1141 if stale_message.is_empty() {
1142 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1143 }
1144
1145 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1146 }
1147
1148 let mut content = Vec::with_capacity(2);
1149
1150 if !stale_message.is_empty() {
1151 content.push(stale_message.into());
1152 }
1153
1154 if !content.is_empty() {
1155 let context_message = LanguageModelRequestMessage {
1156 role: Role::User,
1157 content,
1158 cache: false,
1159 };
1160
1161 messages.push(context_message);
1162 }
1163 }
1164
1165 pub fn stream_completion(
1166 &mut self,
1167 request: LanguageModelRequest,
1168 model: Arc<dyn LanguageModel>,
1169 window: Option<AnyWindowHandle>,
1170 cx: &mut Context<Self>,
1171 ) {
1172 let pending_completion_id = post_inc(&mut self.completion_count);
1173 let mut request_callback_parameters = if self.request_callback.is_some() {
1174 Some((request.clone(), Vec::new()))
1175 } else {
1176 None
1177 };
1178 let prompt_id = self.last_prompt_id.clone();
1179 let tool_use_metadata = ToolUseMetadata {
1180 model: model.clone(),
1181 thread_id: self.id.clone(),
1182 prompt_id: prompt_id.clone(),
1183 };
1184
1185 let task = cx.spawn(async move |thread, cx| {
1186 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1187 let initial_token_usage =
1188 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1189 let stream_completion = async {
1190 let (mut events, usage) = stream_completion_future.await?;
1191
1192 let mut stop_reason = StopReason::EndTurn;
1193 let mut current_token_usage = TokenUsage::default();
1194
1195 if let Some(usage) = usage {
1196 thread
1197 .update(cx, |_thread, cx| {
1198 cx.emit(ThreadEvent::UsageUpdated(usage));
1199 })
1200 .ok();
1201 }
1202
1203 let mut request_assistant_message_id = None;
1204
1205 while let Some(event) = events.next().await {
1206 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1207 response_events
1208 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1209 }
1210
1211 thread.update(cx, |thread, cx| {
1212 let event = match event {
1213 Ok(event) => event,
1214 Err(LanguageModelCompletionError::BadInputJson {
1215 id,
1216 tool_name,
1217 raw_input: invalid_input_json,
1218 json_parse_error,
1219 }) => {
1220 thread.receive_invalid_tool_json(
1221 id,
1222 tool_name,
1223 invalid_input_json,
1224 json_parse_error,
1225 window,
1226 cx,
1227 );
1228 return Ok(());
1229 }
1230 Err(LanguageModelCompletionError::Other(error)) => {
1231 return Err(error);
1232 }
1233 };
1234
1235 match event {
1236 LanguageModelCompletionEvent::StartMessage { .. } => {
1237 request_assistant_message_id =
1238 Some(thread.insert_assistant_message(
1239 vec![MessageSegment::Text(String::new())],
1240 cx,
1241 ));
1242 }
1243 LanguageModelCompletionEvent::Stop(reason) => {
1244 stop_reason = reason;
1245 }
1246 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1247 thread.update_token_usage_at_last_message(token_usage);
1248 thread.cumulative_token_usage = thread.cumulative_token_usage
1249 + token_usage
1250 - current_token_usage;
1251 current_token_usage = token_usage;
1252 }
1253 LanguageModelCompletionEvent::Text(chunk) => {
1254 cx.emit(ThreadEvent::ReceivedTextChunk);
1255 if let Some(last_message) = thread.messages.last_mut() {
1256 if last_message.role == Role::Assistant
1257 && !thread.tool_use.has_tool_results(last_message.id)
1258 {
1259 last_message.push_text(&chunk);
1260 cx.emit(ThreadEvent::StreamedAssistantText(
1261 last_message.id,
1262 chunk,
1263 ));
1264 } else {
1265 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1266 // of a new Assistant response.
1267 //
1268 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1269 // will result in duplicating the text of the chunk in the rendered Markdown.
1270 request_assistant_message_id =
1271 Some(thread.insert_assistant_message(
1272 vec![MessageSegment::Text(chunk.to_string())],
1273 cx,
1274 ));
1275 };
1276 }
1277 }
1278 LanguageModelCompletionEvent::Thinking {
1279 text: chunk,
1280 signature,
1281 } => {
1282 if let Some(last_message) = thread.messages.last_mut() {
1283 if last_message.role == Role::Assistant
1284 && !thread.tool_use.has_tool_results(last_message.id)
1285 {
1286 last_message.push_thinking(&chunk, signature);
1287 cx.emit(ThreadEvent::StreamedAssistantThinking(
1288 last_message.id,
1289 chunk,
1290 ));
1291 } else {
1292 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1293 // of a new Assistant response.
1294 //
1295 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1296 // will result in duplicating the text of the chunk in the rendered Markdown.
1297 request_assistant_message_id =
1298 Some(thread.insert_assistant_message(
1299 vec![MessageSegment::Thinking {
1300 text: chunk.to_string(),
1301 signature,
1302 }],
1303 cx,
1304 ));
1305 };
1306 }
1307 }
1308 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1309 let last_assistant_message_id = request_assistant_message_id
1310 .unwrap_or_else(|| {
1311 let new_assistant_message_id =
1312 thread.insert_assistant_message(vec![], cx);
1313 request_assistant_message_id =
1314 Some(new_assistant_message_id);
1315 new_assistant_message_id
1316 });
1317
1318 let tool_use_id = tool_use.id.clone();
1319 let streamed_input = if tool_use.is_input_complete {
1320 None
1321 } else {
1322 Some((&tool_use.input).clone())
1323 };
1324
1325 let ui_text = thread.tool_use.request_tool_use(
1326 last_assistant_message_id,
1327 tool_use,
1328 tool_use_metadata.clone(),
1329 cx,
1330 );
1331
1332 if let Some(input) = streamed_input {
1333 cx.emit(ThreadEvent::StreamedToolUse {
1334 tool_use_id,
1335 ui_text,
1336 input,
1337 });
1338 }
1339 }
1340 }
1341
1342 thread.touch_updated_at();
1343 cx.emit(ThreadEvent::StreamedCompletion);
1344 cx.notify();
1345
1346 thread.auto_capture_telemetry(cx);
1347 Ok(())
1348 })??;
1349
1350 smol::future::yield_now().await;
1351 }
1352
1353 thread.update(cx, |thread, cx| {
1354 thread
1355 .pending_completions
1356 .retain(|completion| completion.id != pending_completion_id);
1357
1358 // If there is a response without tool use, summarize the message. Otherwise,
1359 // allow two tool uses before summarizing.
1360 if thread.summary.is_none()
1361 && thread.messages.len() >= 2
1362 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1363 {
1364 thread.summarize(cx);
1365 }
1366 })?;
1367
1368 anyhow::Ok(stop_reason)
1369 };
1370
1371 let result = stream_completion.await;
1372
1373 thread
1374 .update(cx, |thread, cx| {
1375 thread.finalize_pending_checkpoint(cx);
1376 match result.as_ref() {
1377 Ok(stop_reason) => match stop_reason {
1378 StopReason::ToolUse => {
1379 let tool_uses = thread.use_pending_tools(window, cx);
1380 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1381 }
1382 StopReason::EndTurn => {}
1383 StopReason::MaxTokens => {}
1384 },
1385 Err(error) => {
1386 if error.is::<PaymentRequiredError>() {
1387 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1388 } else if error.is::<MaxMonthlySpendReachedError>() {
1389 cx.emit(ThreadEvent::ShowError(
1390 ThreadError::MaxMonthlySpendReached,
1391 ));
1392 } else if let Some(error) =
1393 error.downcast_ref::<ModelRequestLimitReachedError>()
1394 {
1395 cx.emit(ThreadEvent::ShowError(
1396 ThreadError::ModelRequestLimitReached { plan: error.plan },
1397 ));
1398 } else if let Some(known_error) =
1399 error.downcast_ref::<LanguageModelKnownError>()
1400 {
1401 match known_error {
1402 LanguageModelKnownError::ContextWindowLimitExceeded {
1403 tokens,
1404 } => {
1405 thread.exceeded_window_error = Some(ExceededWindowError {
1406 model_id: model.id(),
1407 token_count: *tokens,
1408 });
1409 cx.notify();
1410 }
1411 }
1412 } else {
1413 let error_message = error
1414 .chain()
1415 .map(|err| err.to_string())
1416 .collect::<Vec<_>>()
1417 .join("\n");
1418 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1419 header: "Error interacting with language model".into(),
1420 message: SharedString::from(error_message.clone()),
1421 }));
1422 }
1423
1424 thread.cancel_last_completion(window, cx);
1425 }
1426 }
1427 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1428
1429 if let Some((request_callback, (request, response_events))) = thread
1430 .request_callback
1431 .as_mut()
1432 .zip(request_callback_parameters.as_ref())
1433 {
1434 request_callback(request, response_events);
1435 }
1436
1437 thread.auto_capture_telemetry(cx);
1438
1439 if let Ok(initial_usage) = initial_token_usage {
1440 let usage = thread.cumulative_token_usage - initial_usage;
1441
1442 telemetry::event!(
1443 "Assistant Thread Completion",
1444 thread_id = thread.id().to_string(),
1445 prompt_id = prompt_id,
1446 model = model.telemetry_id(),
1447 model_provider = model.provider_id().to_string(),
1448 input_tokens = usage.input_tokens,
1449 output_tokens = usage.output_tokens,
1450 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1451 cache_read_input_tokens = usage.cache_read_input_tokens,
1452 );
1453 }
1454 })
1455 .ok();
1456 });
1457
1458 self.pending_completions.push(PendingCompletion {
1459 id: pending_completion_id,
1460 _task: task,
1461 });
1462 }
1463
    /// Spawns a background task that asks the configured summary model for a
    /// short (3-7 word) title for this thread.
    ///
    /// Does nothing when no summary model is configured or its provider is
    /// not authenticated. On success sets `self.summary` and emits
    /// `ThreadEvent::SummaryGenerated`; errors are logged and swallowed.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            // Inner async block so `.log_err()` can log-and-discard any error.
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    // Only the first line is used as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // An empty result leaves any existing summary untouched.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1518
    /// Starts generating a detailed Markdown summary of the conversation.
    ///
    /// Returns `None` when there are no messages, when a summary for the
    /// current last message is already generated or in flight, when no
    /// summary model is configured, or when its provider is unauthenticated.
    /// Otherwise transitions `detailed_summary_state` to `Generating` and
    /// returns the task that will populate it.
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request itself fails, roll the state back to NotGenerated
            // so a later call can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            // Accumulate the streamed chunks; failed chunks are logged and skipped.
            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        // Mark as in-flight before handing the task to the caller.
        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1585
    /// Whether a detailed-summary generation task is currently in flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1592
1593 pub fn use_pending_tools(
1594 &mut self,
1595 window: Option<AnyWindowHandle>,
1596 cx: &mut Context<Self>,
1597 ) -> Vec<PendingToolUse> {
1598 self.auto_capture_telemetry(cx);
1599 let request = self.to_completion_request(cx);
1600 let messages = Arc::new(request.messages);
1601 let pending_tool_uses = self
1602 .tool_use
1603 .pending_tool_uses()
1604 .into_iter()
1605 .filter(|tool_use| tool_use.status.is_idle())
1606 .cloned()
1607 .collect::<Vec<_>>();
1608
1609 for tool_use in pending_tool_uses.iter() {
1610 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1611 if tool.needs_confirmation(&tool_use.input, cx)
1612 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1613 {
1614 self.tool_use.confirm_tool_use(
1615 tool_use.id.clone(),
1616 tool_use.ui_text.clone(),
1617 tool_use.input.clone(),
1618 messages.clone(),
1619 tool,
1620 );
1621 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1622 } else {
1623 self.run_tool(
1624 tool_use.id.clone(),
1625 tool_use.ui_text.clone(),
1626 tool_use.input.clone(),
1627 &messages,
1628 tool,
1629 window,
1630 cx,
1631 );
1632 }
1633 }
1634 }
1635
1636 pending_tool_uses
1637 }
1638
1639 pub fn receive_invalid_tool_json(
1640 &mut self,
1641 tool_use_id: LanguageModelToolUseId,
1642 tool_name: Arc<str>,
1643 invalid_json: Arc<str>,
1644 error: String,
1645 window: Option<AnyWindowHandle>,
1646 cx: &mut Context<Thread>,
1647 ) {
1648 log::error!("The model returned invalid input JSON: {invalid_json}");
1649
1650 let pending_tool_use = self.tool_use.insert_tool_output(
1651 tool_use_id.clone(),
1652 tool_name,
1653 Err(anyhow!("Error parsing input JSON: {error}")),
1654 cx,
1655 );
1656 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1657 pending_tool_use.ui_text.clone()
1658 } else {
1659 log::error!(
1660 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1661 );
1662 format!("Unknown tool {}", tool_use_id).into()
1663 };
1664
1665 cx.emit(ThreadEvent::InvalidToolInput {
1666 tool_use_id: tool_use_id.clone(),
1667 ui_text,
1668 invalid_input_json: invalid_json,
1669 });
1670
1671 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1672 }
1673
1674 pub fn run_tool(
1675 &mut self,
1676 tool_use_id: LanguageModelToolUseId,
1677 ui_text: impl Into<SharedString>,
1678 input: serde_json::Value,
1679 messages: &[LanguageModelRequestMessage],
1680 tool: Arc<dyn Tool>,
1681 window: Option<AnyWindowHandle>,
1682 cx: &mut Context<Thread>,
1683 ) {
1684 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1685 self.tool_use
1686 .run_pending_tool(tool_use_id, ui_text.into(), task);
1687 }
1688
1689 fn spawn_tool_use(
1690 &mut self,
1691 tool_use_id: LanguageModelToolUseId,
1692 messages: &[LanguageModelRequestMessage],
1693 input: serde_json::Value,
1694 tool: Arc<dyn Tool>,
1695 window: Option<AnyWindowHandle>,
1696 cx: &mut Context<Thread>,
1697 ) -> Task<()> {
1698 let tool_name: Arc<str> = tool.name().into();
1699
1700 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1701 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1702 } else {
1703 tool.run(
1704 input,
1705 messages,
1706 self.project.clone(),
1707 self.action_log.clone(),
1708 window,
1709 cx,
1710 )
1711 };
1712
1713 // Store the card separately if it exists
1714 if let Some(card) = tool_result.card.clone() {
1715 self.tool_use
1716 .insert_tool_result_card(tool_use_id.clone(), card);
1717 }
1718
1719 cx.spawn({
1720 async move |thread: WeakEntity<Thread>, cx| {
1721 let output = tool_result.output.await;
1722
1723 thread
1724 .update(cx, |thread, cx| {
1725 let pending_tool_use = thread.tool_use.insert_tool_output(
1726 tool_use_id.clone(),
1727 tool_name,
1728 output,
1729 cx,
1730 );
1731 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1732 })
1733 .ok();
1734 }
1735 })
1736 }
1737
1738 fn tool_finished(
1739 &mut self,
1740 tool_use_id: LanguageModelToolUseId,
1741 pending_tool_use: Option<PendingToolUse>,
1742 canceled: bool,
1743 window: Option<AnyWindowHandle>,
1744 cx: &mut Context<Self>,
1745 ) {
1746 if self.all_tools_finished() {
1747 let model_registry = LanguageModelRegistry::read_global(cx);
1748 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1749 if !canceled {
1750 self.send_to_model(model, window, cx);
1751 }
1752 self.auto_capture_telemetry(cx);
1753 }
1754 }
1755
1756 cx.emit(ThreadEvent::ToolFinished {
1757 tool_use_id,
1758 pending_tool_use,
1759 });
1760 }
1761
1762 /// Cancels the last pending completion, if there are any pending.
1763 ///
1764 /// Returns whether a completion was canceled.
1765 pub fn cancel_last_completion(
1766 &mut self,
1767 window: Option<AnyWindowHandle>,
1768 cx: &mut Context<Self>,
1769 ) -> bool {
1770 let mut canceled = self.pending_completions.pop().is_some();
1771
1772 for pending_tool_use in self.tool_use.cancel_pending() {
1773 canceled = true;
1774 self.tool_finished(
1775 pending_tool_use.id.clone(),
1776 Some(pending_tool_use),
1777 true,
1778 window,
1779 cx,
1780 );
1781 }
1782
1783 self.finalize_pending_checkpoint(cx);
1784 canceled
1785 }
1786
    /// Thread-level feedback, if the user has rated this thread.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1790
    /// Feedback the user gave on a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1794
    /// Records the user's rating of a message and reports it (with a full
    /// thread serialization and project snapshot) to telemetry.
    ///
    /// A repeat of the same rating for the same message is a no-op. The
    /// returned task completes once the telemetry has been flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Skip duplicate submissions of the same rating.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure still reports the rating, with null data.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
1852
    /// Records a rating for the whole thread.
    ///
    /// Attributed to the last assistant message when one exists (delegating
    /// to [`Self::report_message_feedback`]); otherwise stored as
    /// thread-level feedback and reported directly to telemetry.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach to: record thread-level feedback.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // A serialization failure still reports the rating, with null data.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
1898
1899 /// Create a snapshot of the current project state including git information and unsaved buffers.
1900 fn project_snapshot(
1901 project: Entity<Project>,
1902 cx: &mut Context<Self>,
1903 ) -> Task<Arc<ProjectSnapshot>> {
1904 let git_store = project.read(cx).git_store().clone();
1905 let worktree_snapshots: Vec<_> = project
1906 .read(cx)
1907 .visible_worktrees(cx)
1908 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1909 .collect();
1910
1911 cx.spawn(async move |_, cx| {
1912 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1913
1914 let mut unsaved_buffers = Vec::new();
1915 cx.update(|app_cx| {
1916 let buffer_store = project.read(app_cx).buffer_store();
1917 for buffer_handle in buffer_store.read(app_cx).buffers() {
1918 let buffer = buffer_handle.read(app_cx);
1919 if buffer.is_dirty() {
1920 if let Some(file) = buffer.file() {
1921 let path = file.path().to_string_lossy().to_string();
1922 unsaved_buffers.push(path);
1923 }
1924 }
1925 }
1926 })
1927 .ok();
1928
1929 Arc::new(ProjectSnapshot {
1930 worktree_snapshots,
1931 unsaved_buffer_paths: unsaved_buffers,
1932 timestamp: Utc::now(),
1933 })
1934 })
1935 }
1936
1937 fn worktree_snapshot(
1938 worktree: Entity<project::Worktree>,
1939 git_store: Entity<GitStore>,
1940 cx: &App,
1941 ) -> Task<WorktreeSnapshot> {
1942 cx.spawn(async move |cx| {
1943 // Get worktree path and snapshot
1944 let worktree_info = cx.update(|app_cx| {
1945 let worktree = worktree.read(app_cx);
1946 let path = worktree.abs_path().to_string_lossy().to_string();
1947 let snapshot = worktree.snapshot();
1948 (path, snapshot)
1949 });
1950
1951 let Ok((worktree_path, _snapshot)) = worktree_info else {
1952 return WorktreeSnapshot {
1953 worktree_path: String::new(),
1954 git_state: None,
1955 };
1956 };
1957
1958 let git_state = git_store
1959 .update(cx, |git_store, cx| {
1960 git_store
1961 .repositories()
1962 .values()
1963 .find(|repo| {
1964 repo.read(cx)
1965 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1966 .is_some()
1967 })
1968 .cloned()
1969 })
1970 .ok()
1971 .flatten()
1972 .map(|repo| {
1973 repo.update(cx, |repo, _| {
1974 let current_branch =
1975 repo.branch.as_ref().map(|branch| branch.name.to_string());
1976 repo.send_job(None, |state, _| async move {
1977 let RepositoryState::Local { backend, .. } = state else {
1978 return GitState {
1979 remote_url: None,
1980 head_sha: None,
1981 current_branch,
1982 diff: None,
1983 };
1984 };
1985
1986 let remote_url = backend.remote_url("origin");
1987 let head_sha = backend.head_sha().await;
1988 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1989
1990 GitState {
1991 remote_url,
1992 head_sha,
1993 current_branch,
1994 diff,
1995 }
1996 })
1997 })
1998 });
1999
2000 let git_state = match git_state {
2001 Some(git_state) => match git_state.ok() {
2002 Some(git_state) => git_state.await.ok(),
2003 None => None,
2004 },
2005 None => None,
2006 };
2007
2008 WorktreeSnapshot {
2009 worktree_path,
2010 git_state,
2011 }
2012 })
2013 }
2014
2015 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2016 let mut markdown = Vec::new();
2017
2018 if let Some(summary) = self.summary() {
2019 writeln!(markdown, "# {summary}\n")?;
2020 };
2021
2022 for message in self.messages() {
2023 writeln!(
2024 markdown,
2025 "## {role}\n",
2026 role = match message.role {
2027 Role::User => "User",
2028 Role::Assistant => "Assistant",
2029 Role::System => "System",
2030 }
2031 )?;
2032
2033 if !message.loaded_context.text.is_empty() {
2034 writeln!(markdown, "{}", message.loaded_context.text)?;
2035 }
2036
2037 if !message.loaded_context.images.is_empty() {
2038 writeln!(
2039 markdown,
2040 "\n{} images attached as context.\n",
2041 message.loaded_context.images.len()
2042 )?;
2043 }
2044
2045 for segment in &message.segments {
2046 match segment {
2047 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2048 MessageSegment::Thinking { text, .. } => {
2049 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2050 }
2051 MessageSegment::RedactedThinking(_) => {}
2052 }
2053 }
2054
2055 for tool_use in self.tool_uses_for_message(message.id, cx) {
2056 writeln!(
2057 markdown,
2058 "**Use Tool: {} ({})**",
2059 tool_use.name, tool_use.id
2060 )?;
2061 writeln!(markdown, "```json")?;
2062 writeln!(
2063 markdown,
2064 "{}",
2065 serde_json::to_string_pretty(&tool_use.input)?
2066 )?;
2067 writeln!(markdown, "```")?;
2068 }
2069
2070 for tool_result in self.tool_results_for_message(message.id) {
2071 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2072 if tool_result.is_error {
2073 write!(markdown, " (Error)")?;
2074 }
2075
2076 writeln!(markdown, "**\n")?;
2077 writeln!(markdown, "{}", tool_result.content)?;
2078 }
2079 }
2080
2081 Ok(String::from_utf8_lossy(&markdown).to_string())
2082 }
2083
2084 pub fn keep_edits_in_range(
2085 &mut self,
2086 buffer: Entity<language::Buffer>,
2087 buffer_range: Range<language::Anchor>,
2088 cx: &mut Context<Self>,
2089 ) {
2090 self.action_log.update(cx, |action_log, cx| {
2091 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2092 });
2093 }
2094
2095 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2096 self.action_log
2097 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2098 }
2099
2100 pub fn reject_edits_in_ranges(
2101 &mut self,
2102 buffer: Entity<language::Buffer>,
2103 buffer_ranges: Vec<Range<language::Anchor>>,
2104 cx: &mut Context<Self>,
2105 ) -> Task<Result<()>> {
2106 self.action_log.update(cx, |action_log, cx| {
2107 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2108 })
2109 }
2110
    /// The log of buffer edits the agent has made in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2114
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2118
2119 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2120 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2121 return;
2122 }
2123
2124 let now = Instant::now();
2125 if let Some(last) = self.last_auto_capture_at {
2126 if now.duration_since(last).as_secs() < 10 {
2127 return;
2128 }
2129 }
2130
2131 self.last_auto_capture_at = Some(now);
2132
2133 let thread_id = self.id().clone();
2134 let github_login = self
2135 .project
2136 .read(cx)
2137 .user_store()
2138 .read(cx)
2139 .current_user()
2140 .map(|user| user.github_login.clone());
2141 let client = self.project.read(cx).client().clone();
2142 let serialize_task = self.serialize(cx);
2143
2144 cx.background_executor()
2145 .spawn(async move {
2146 if let Ok(serialized_thread) = serialize_task.await {
2147 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2148 telemetry::event!(
2149 "Agent Thread Auto-Captured",
2150 thread_id = thread_id.to_string(),
2151 thread_data = thread_data,
2152 auto_capture_reason = "tracked_user",
2153 github_login = github_login
2154 );
2155
2156 client.telemetry().flush_events().await;
2157 }
2158 }
2159 })
2160 .detach();
2161 }
2162
    /// Total token usage accumulated across every completion in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2166
2167 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2168 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2169 return TotalTokenUsage::default();
2170 };
2171
2172 let max = model.model.max_token_count();
2173
2174 let index = self
2175 .messages
2176 .iter()
2177 .position(|msg| msg.id == message_id)
2178 .unwrap_or(0);
2179
2180 if index == 0 {
2181 return TotalTokenUsage { total: 0, max };
2182 }
2183
2184 let token_usage = &self
2185 .request_token_usage
2186 .get(index - 1)
2187 .cloned()
2188 .unwrap_or_default();
2189
2190 TotalTokenUsage {
2191 total: token_usage.total_tokens() as usize,
2192 max,
2193 }
2194 }
2195
2196 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2197 let model_registry = LanguageModelRegistry::read_global(cx);
2198 let Some(model) = model_registry.default_model() else {
2199 return TotalTokenUsage::default();
2200 };
2201
2202 let max = model.model.max_token_count();
2203
2204 if let Some(exceeded_error) = &self.exceeded_window_error {
2205 if model.model.id() == exceeded_error.model_id {
2206 return TotalTokenUsage {
2207 total: exceeded_error.token_count,
2208 max,
2209 };
2210 }
2211 }
2212
2213 let total = self
2214 .token_usage_at_last_message()
2215 .unwrap_or_default()
2216 .total_tokens() as usize;
2217
2218 TotalTokenUsage { total, max }
2219 }
2220
    /// Token usage recorded for the most recent message, if any.
    ///
    /// Looks up index `messages.len() - 1` (saturating to 0 for an empty
    /// thread); when no usage was recorded at that index, falls back to the
    /// most recently recorded usage.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2227
    /// Records `token_usage` against the last message, growing the usage
    /// vector to match the current message count.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // why: slots created by the resize (for messages added since the last
        // update) are filled with the latest known usage rather than zeros.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2237
2238 pub fn deny_tool_use(
2239 &mut self,
2240 tool_use_id: LanguageModelToolUseId,
2241 tool_name: Arc<str>,
2242 window: Option<AnyWindowHandle>,
2243 cx: &mut Context<Self>,
2244 ) {
2245 let err = Err(anyhow::anyhow!(
2246 "Permission to run tool action denied by user"
2247 ));
2248
2249 self.tool_use
2250 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2251 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2252 }
2253}
2254
/// Errors surfaced to the user while interacting with a language model.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider requires payment before serving further requests.
    #[error("Payment required")]
    PaymentRequired,
    /// The account's monthly spend cap has been hit.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model-request quota has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, carried as a displayable header and message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2269
/// Events emitted by a [`Thread`] as completions stream and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be displayed to the user.
    ShowError(ThreadError),
    /// The provider reported updated request usage.
    UsageUpdated(RequestUsage),
    /// A streamed completion event was applied to the thread.
    StreamedCompletion,
    /// A text chunk arrived (emitted before the chunk is applied).
    ReceivedTextChunk,
    /// Text was appended to the given assistant message.
    StreamedAssistantText(MessageId, String),
    /// Thinking text was appended to the given assistant message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool call's input is still streaming; `input` holds the partial JSON.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input JSON that failed to parse.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, with its stop reason or terminal error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A thread title was generated by the summary model.
    SummaryGenerated,
    SummaryChanged,
    /// The completion stopped for tool use; these tools were dispatched.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A single tool finished producing output (or was canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires explicit user confirmation before it may run.
    ToolConfirmationNeeded,
}
2306
// Marker impl: lets `Thread` entities broadcast `ThreadEvent`s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2308
/// A completion request that is currently in flight for a thread.
struct PendingCompletion {
    // Identifier used to distinguish this completion from others.
    id: usize,
    // Holds the streaming task so it keeps running; presumably the work is
    // canceled when this is dropped (gpui task semantics) — TODO confirm.
    _task: Task<()>,
}
2313
// Unit tests covering context attachment, context deduplication across
// messages, and stale-buffer notifications for `Thread`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // A file attached as context should be rendered into both the stored
    // message and the outgoing completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Two request messages: one preceding message plus our user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Each message should carry only contexts not already included by an
    // earlier message in the same thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    // Messages inserted with an empty context load should carry no context
    // text at all, in either the message or the request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Editing a buffer after it was attached as context should cause the
    // next request to end with a stale-file notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // Registers every settings type the thread machinery reads during tests.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Builds the full test fixture: workspace, thread store, a fresh thread,
    // and an empty context store, all backed by the given project.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    // Opens `path` in the project and registers its buffer with the context
    // store, returning the buffer so tests can edit it afterward.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}