1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
42
43#[derive(
44 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
45)]
46pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
66/// The ID of the user prompt that initiated a request.
67///
68/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
70pub struct PromptId(Arc<str>);
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
84#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
85pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
93/// A message in a [`Thread`].
94#[derive(Debug, Clone)]
95pub struct Message {
96 pub id: MessageId,
97 pub role: Role,
98 pub segments: Vec<MessageSegment>,
99 pub context: String,
100}
101
102impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    /// The model sometimes runs a tool without producing any text, or produces only a marker ([`USING_TOOL_MARKER`]).
105 pub fn should_display_content(&self) -> bool {
106 self.segments.iter().all(|segment| segment.should_display())
107 }
108
109 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
110 if let Some(MessageSegment::Thinking {
111 text: segment,
112 signature: current_signature,
113 }) = self.segments.last_mut()
114 {
115 if let Some(signature) = signature {
116 *current_signature = Some(signature);
117 }
118 segment.push_str(text);
119 } else {
120 self.segments.push(MessageSegment::Thinking {
121 text: text.to_string(),
122 signature,
123 });
124 }
125 }
126
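    /// Appends `text` to the last segment if it is a text segment; otherwise starts
    /// a new `MessageSegment::Text`.
    ///
    /// A minimal sketch of the coalescing behavior (assuming `message` is an existing
    /// [`Message`]):
    ///
    /// ```ignore
    /// message.push_text("Hello");
    /// message.push_text(", world");
    /// // Both chunks end up in a single `MessageSegment::Text("Hello, world")`.
    /// ```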
127 pub fn push_text(&mut self, text: &str) {
128 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
129 segment.push_str(text);
130 } else {
131 self.segments.push(MessageSegment::Text(text.to_string()));
132 }
133 }
134
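    /// Renders the message as plain text: any attached context first, then each
    /// segment, with thinking segments wrapped in `<think>` tags and redacted
    /// thinking omitted.
    ///
    /// A minimal sketch (assuming a message with no context and a single text segment):
    ///
    /// ```ignore
    /// assert_eq!(message.to_string(), "Hello");
    /// ```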
135 pub fn to_string(&self) -> String {
136 let mut result = String::new();
137
138 if !self.context.is_empty() {
139 result.push_str(&self.context);
140 }
141
142 for segment in &self.segments {
143 match segment {
144 MessageSegment::Text(text) => result.push_str(text),
145 MessageSegment::Thinking { text, .. } => {
146 result.push_str("<think>\n");
147 result.push_str(text);
148 result.push_str("\n</think>");
149 }
150 MessageSegment::RedactedThinking(_) => {}
151 }
152 }
153
154 result
155 }
156}
157
158#[derive(Debug, Clone, PartialEq, Eq)]
159pub enum MessageSegment {
160 Text(String),
161 Thinking {
162 text: String,
163 signature: Option<String>,
164 },
165 RedactedThinking(Vec<u8>),
166}
167
impl MessageSegment {
    /// Whether this segment has any content worth displaying.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct ProjectSnapshot {
180 pub worktree_snapshots: Vec<WorktreeSnapshot>,
181 pub unsaved_buffer_paths: Vec<String>,
182 pub timestamp: DateTime<Utc>,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct WorktreeSnapshot {
187 pub worktree_path: String,
188 pub git_state: Option<GitState>,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct GitState {
193 pub remote_url: Option<String>,
194 pub head_sha: Option<String>,
195 pub current_branch: Option<String>,
196 pub diff: Option<String>,
197}
198
199#[derive(Clone)]
200pub struct ThreadCheckpoint {
201 message_id: MessageId,
202 git_checkpoint: GitStoreCheckpoint,
203}
204
205#[derive(Copy, Clone, Debug, PartialEq, Eq)]
206pub enum ThreadFeedback {
207 Positive,
208 Negative,
209}
210
211pub enum LastRestoreCheckpoint {
212 Pending {
213 message_id: MessageId,
214 },
215 Error {
216 message_id: MessageId,
217 error: String,
218 },
219}
220
221impl LastRestoreCheckpoint {
222 pub fn message_id(&self) -> MessageId {
223 match self {
224 LastRestoreCheckpoint::Pending { message_id } => *message_id,
225 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
226 }
227 }
228}
229
230#[derive(Clone, Debug, Default, Serialize, Deserialize)]
231pub enum DetailedSummaryState {
232 #[default]
233 NotGenerated,
234 Generating {
235 message_id: MessageId,
236 },
237 Generated {
238 text: SharedString,
239 message_id: MessageId,
240 },
241}
242
243#[derive(Default)]
244pub struct TotalTokenUsage {
245 pub total: usize,
246 pub max: usize,
247}
248
249impl TotalTokenUsage {
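    /// Classifies usage relative to the context window: `Exceeded` once `total`
    /// reaches `max`, `Warning` at or above the warning threshold (0.8 by default;
    /// overridable via `ZED_THREAD_WARNING_THRESHOLD` in debug builds), and
    /// `Normal` otherwise.
    ///
    /// A minimal sketch with the default threshold:
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 900, max: 1000 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Warning);
    /// ```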
250 pub fn ratio(&self) -> TokenUsageRatio {
251 #[cfg(debug_assertions)]
252 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
253 .unwrap_or("0.8".to_string())
254 .parse()
255 .unwrap();
256 #[cfg(not(debug_assertions))]
257 let warning_threshold: f32 = 0.8;
258
259 if self.total >= self.max {
260 TokenUsageRatio::Exceeded
261 } else if self.total as f32 / self.max as f32 >= warning_threshold {
262 TokenUsageRatio::Warning
263 } else {
264 TokenUsageRatio::Normal
265 }
266 }
267
268 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
269 TotalTokenUsage {
270 total: self.total + tokens,
271 max: self.max,
272 }
273 }
274}
275
276#[derive(Debug, Default, PartialEq, Eq)]
277pub enum TokenUsageRatio {
278 #[default]
279 Normal,
280 Warning,
281 Exceeded,
282}
283
284/// A thread of conversation with the LLM.
285pub struct Thread {
286 id: ThreadId,
287 updated_at: DateTime<Utc>,
288 summary: Option<SharedString>,
289 pending_summary: Task<Option<()>>,
290 detailed_summary_state: DetailedSummaryState,
291 messages: Vec<Message>,
292 next_message_id: MessageId,
293 last_prompt_id: PromptId,
294 context: BTreeMap<ContextId, AssistantContext>,
295 context_by_message: HashMap<MessageId, Vec<ContextId>>,
296 project_context: SharedProjectContext,
297 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
298 completion_count: usize,
299 pending_completions: Vec<PendingCompletion>,
300 project: Entity<Project>,
301 prompt_builder: Arc<PromptBuilder>,
302 tools: Entity<ToolWorkingSet>,
303 tool_use: ToolUseState,
304 action_log: Entity<ActionLog>,
305 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
306 pending_checkpoint: Option<ThreadCheckpoint>,
307 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
308 request_token_usage: Vec<TokenUsage>,
309 cumulative_token_usage: TokenUsage,
310 exceeded_window_error: Option<ExceededWindowError>,
311 feedback: Option<ThreadFeedback>,
312 message_feedback: HashMap<MessageId, ThreadFeedback>,
313 last_auto_capture_at: Option<Instant>,
314 request_callback: Option<
315 Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
316 >,
317}
318
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct ExceededWindowError {
    /// The model in use when the last message exceeded the context window.
    model_id: LanguageModelId,
    /// Token count including the last message.
    token_count: usize,
325}
326
327impl Thread {
328 pub fn new(
329 project: Entity<Project>,
330 tools: Entity<ToolWorkingSet>,
331 prompt_builder: Arc<PromptBuilder>,
332 system_prompt: SharedProjectContext,
333 cx: &mut Context<Self>,
334 ) -> Self {
335 Self {
336 id: ThreadId::new(),
337 updated_at: Utc::now(),
338 summary: None,
339 pending_summary: Task::ready(None),
340 detailed_summary_state: DetailedSummaryState::NotGenerated,
341 messages: Vec::new(),
342 next_message_id: MessageId(0),
343 last_prompt_id: PromptId::new(),
344 context: BTreeMap::default(),
345 context_by_message: HashMap::default(),
346 project_context: system_prompt,
347 checkpoints_by_message: HashMap::default(),
348 completion_count: 0,
349 pending_completions: Vec::new(),
350 project: project.clone(),
351 prompt_builder,
352 tools: tools.clone(),
353 last_restore_checkpoint: None,
354 pending_checkpoint: None,
355 tool_use: ToolUseState::new(tools.clone()),
356 action_log: cx.new(|_| ActionLog::new(project.clone())),
357 initial_project_snapshot: {
358 let project_snapshot = Self::project_snapshot(project, cx);
359 cx.foreground_executor()
360 .spawn(async move { Some(project_snapshot.await) })
361 .shared()
362 },
363 request_token_usage: Vec::new(),
364 cumulative_token_usage: TokenUsage::default(),
365 exceeded_window_error: None,
366 feedback: None,
367 message_feedback: HashMap::default(),
368 last_auto_capture_at: None,
369 request_callback: None,
370 }
371 }
372
373 pub fn deserialize(
374 id: ThreadId,
375 serialized: SerializedThread,
376 project: Entity<Project>,
377 tools: Entity<ToolWorkingSet>,
378 prompt_builder: Arc<PromptBuilder>,
379 project_context: SharedProjectContext,
380 cx: &mut Context<Self>,
381 ) -> Self {
382 let next_message_id = MessageId(
383 serialized
384 .messages
385 .last()
386 .map(|message| message.id.0 + 1)
387 .unwrap_or(0),
388 );
389 let tool_use =
390 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
391
392 Self {
393 id,
394 updated_at: serialized.updated_at,
395 summary: Some(serialized.summary),
396 pending_summary: Task::ready(None),
397 detailed_summary_state: serialized.detailed_summary_state,
398 messages: serialized
399 .messages
400 .into_iter()
401 .map(|message| Message {
402 id: message.id,
403 role: message.role,
404 segments: message
405 .segments
406 .into_iter()
407 .map(|segment| match segment {
408 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
409 SerializedMessageSegment::Thinking { text, signature } => {
410 MessageSegment::Thinking { text, signature }
411 }
412 SerializedMessageSegment::RedactedThinking { data } => {
413 MessageSegment::RedactedThinking(data)
414 }
415 })
416 .collect(),
417 context: message.context,
418 })
419 .collect(),
420 next_message_id,
421 last_prompt_id: PromptId::new(),
422 context: BTreeMap::default(),
423 context_by_message: HashMap::default(),
424 project_context,
425 checkpoints_by_message: HashMap::default(),
426 completion_count: 0,
427 pending_completions: Vec::new(),
428 last_restore_checkpoint: None,
429 pending_checkpoint: None,
430 project: project.clone(),
431 prompt_builder,
432 tools,
433 tool_use,
434 action_log: cx.new(|_| ActionLog::new(project)),
435 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
436 request_token_usage: serialized.request_token_usage,
437 cumulative_token_usage: serialized.cumulative_token_usage,
438 exceeded_window_error: None,
439 feedback: None,
440 message_feedback: HashMap::default(),
441 last_auto_capture_at: None,
442 request_callback: None,
443 }
444 }
445
446 pub fn set_request_callback(
447 &mut self,
448 callback: impl 'static
449 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
450 ) {
451 self.request_callback = Some(Box::new(callback));
452 }
453
454 pub fn id(&self) -> &ThreadId {
455 &self.id
456 }
457
458 pub fn is_empty(&self) -> bool {
459 self.messages.is_empty()
460 }
461
462 pub fn updated_at(&self) -> DateTime<Utc> {
463 self.updated_at
464 }
465
466 pub fn touch_updated_at(&mut self) {
467 self.updated_at = Utc::now();
468 }
469
470 pub fn advance_prompt_id(&mut self) {
471 self.last_prompt_id = PromptId::new();
472 }
473
474 pub fn summary(&self) -> Option<SharedString> {
475 self.summary.clone()
476 }
477
478 pub fn project_context(&self) -> SharedProjectContext {
479 self.project_context.clone()
480 }
481
482 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
483
484 pub fn summary_or_default(&self) -> SharedString {
485 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
486 }
487
488 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
489 let Some(current_summary) = &self.summary else {
490 // Don't allow setting summary until generated
491 return;
492 };
493
494 let mut new_summary = new_summary.into();
495
496 if new_summary.is_empty() {
497 new_summary = Self::DEFAULT_SUMMARY;
498 }
499
500 if current_summary != &new_summary {
501 self.summary = Some(new_summary);
502 cx.emit(ThreadEvent::SummaryChanged);
503 }
504 }
505
506 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
507 self.latest_detailed_summary()
508 .unwrap_or_else(|| self.text().into())
509 }
510
511 fn latest_detailed_summary(&self) -> Option<SharedString> {
512 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
513 Some(text.clone())
514 } else {
515 None
516 }
517 }
518
519 pub fn message(&self, id: MessageId) -> Option<&Message> {
520 self.messages.iter().find(|message| message.id == id)
521 }
522
523 pub fn messages(&self) -> impl Iterator<Item = &Message> {
524 self.messages.iter()
525 }
526
527 pub fn is_generating(&self) -> bool {
528 !self.pending_completions.is_empty() || !self.all_tools_finished()
529 }
530
531 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
532 &self.tools
533 }
534
535 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
536 self.tool_use
537 .pending_tool_uses()
538 .into_iter()
539 .find(|tool_use| &tool_use.id == id)
540 }
541
542 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
543 self.tool_use
544 .pending_tool_uses()
545 .into_iter()
546 .filter(|tool_use| tool_use.status.needs_confirmation())
547 }
548
549 pub fn has_pending_tool_uses(&self) -> bool {
550 !self.tool_use.pending_tool_uses().is_empty()
551 }
552
553 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
554 self.checkpoints_by_message.get(&id).cloned()
555 }
556
557 pub fn restore_checkpoint(
558 &mut self,
559 checkpoint: ThreadCheckpoint,
560 cx: &mut Context<Self>,
561 ) -> Task<Result<()>> {
562 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
563 message_id: checkpoint.message_id,
564 });
565 cx.emit(ThreadEvent::CheckpointChanged);
566 cx.notify();
567
568 let git_store = self.project().read(cx).git_store().clone();
569 let restore = git_store.update(cx, |git_store, cx| {
570 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
571 });
572
573 cx.spawn(async move |this, cx| {
574 let result = restore.await;
575 this.update(cx, |this, cx| {
576 if let Err(err) = result.as_ref() {
577 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
578 message_id: checkpoint.message_id,
579 error: err.to_string(),
580 });
581 } else {
582 this.truncate(checkpoint.message_id, cx);
583 this.last_restore_checkpoint = None;
584 }
585 this.pending_checkpoint = None;
586 cx.emit(ThreadEvent::CheckpointChanged);
587 cx.notify();
588 })?;
589 result
590 })
591 }
592
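    /// Once generation has finished, compares the pending checkpoint against a fresh
    /// checkpoint of the repository: if the two are equal the pending checkpoint is
    /// discarded, otherwise it is recorded against its message.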
593 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
594 let pending_checkpoint = if self.is_generating() {
595 return;
596 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
597 checkpoint
598 } else {
599 return;
600 };
601
602 let git_store = self.project.read(cx).git_store().clone();
603 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
604 cx.spawn(async move |this, cx| match final_checkpoint.await {
605 Ok(final_checkpoint) => {
606 let equal = git_store
607 .update(cx, |store, cx| {
608 store.compare_checkpoints(
609 pending_checkpoint.git_checkpoint.clone(),
610 final_checkpoint.clone(),
611 cx,
612 )
613 })?
614 .await
615 .unwrap_or(false);
616
617 if equal {
618 git_store
619 .update(cx, |store, cx| {
620 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
621 })?
622 .detach();
623 } else {
624 this.update(cx, |this, cx| {
625 this.insert_checkpoint(pending_checkpoint, cx)
626 })?;
627 }
628
629 git_store
630 .update(cx, |store, cx| {
631 store.delete_checkpoint(final_checkpoint, cx)
632 })?
633 .detach();
634
635 Ok(())
636 }
637 Err(_) => this.update(cx, |this, cx| {
638 this.insert_checkpoint(pending_checkpoint, cx)
639 }),
640 })
641 .detach();
642 }
643
644 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
645 self.checkpoints_by_message
646 .insert(checkpoint.message_id, checkpoint);
647 cx.emit(ThreadEvent::CheckpointChanged);
648 cx.notify();
649 }
650
651 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
652 self.last_restore_checkpoint.as_ref()
653 }
654
655 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
656 let Some(message_ix) = self
657 .messages
658 .iter()
659 .rposition(|message| message.id == message_id)
660 else {
661 return;
662 };
663 for deleted_message in self.messages.drain(message_ix..) {
664 self.context_by_message.remove(&deleted_message.id);
665 self.checkpoints_by_message.remove(&deleted_message.id);
666 }
667 cx.notify();
668 }
669
670 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
671 self.context_by_message
672 .get(&id)
673 .into_iter()
674 .flat_map(|context| {
675 context
676 .iter()
                    .filter_map(|context_id| self.context.get(context_id))
678 })
679 }
680
681 /// Returns whether all of the tool uses have finished running.
682 pub fn all_tools_finished(&self) -> bool {
683 // If the only pending tool uses left are the ones with errors, then
684 // that means that we've finished running all of the pending tools.
685 self.tool_use
686 .pending_tool_uses()
687 .iter()
688 .all(|tool_use| tool_use.status.is_error())
689 }
690
691 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
692 self.tool_use.tool_uses_for_message(id, cx)
693 }
694
695 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
696 self.tool_use.tool_results_for_message(id)
697 }
698
699 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
700 self.tool_use.tool_result(id)
701 }
702
703 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
704 Some(&self.tool_use.tool_result(id)?.content)
705 }
706
707 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
708 self.tool_use.tool_result_card(id).cloned()
709 }
710
711 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
712 self.tool_use.message_has_tool_results(message_id)
713 }
714
    /// Filters out contexts that have already been included in previous messages.
716 pub fn filter_new_context<'a>(
717 &self,
718 context: impl Iterator<Item = &'a AssistantContext>,
719 ) -> impl Iterator<Item = &'a AssistantContext> {
720 context.filter(|ctx| self.is_context_new(ctx))
721 }
722
723 fn is_context_new(&self, context: &AssistantContext) -> bool {
724 !self.context.contains_key(&context.id())
725 }
726
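    /// Inserts a user message containing `text`, attaches any context entries that
    /// have not already been sent earlier in the thread, and, when a git checkpoint
    /// is provided, stores it as the pending checkpoint for this message.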
727 pub fn insert_user_message(
728 &mut self,
729 text: impl Into<String>,
730 context: Vec<AssistantContext>,
731 git_checkpoint: Option<GitStoreCheckpoint>,
732 cx: &mut Context<Self>,
733 ) -> MessageId {
734 let text = text.into();
735
736 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
737
738 let new_context: Vec<_> = context
739 .into_iter()
740 .filter(|ctx| self.is_context_new(ctx))
741 .collect();
742
743 if !new_context.is_empty() {
744 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
745 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
746 message.context = context_string;
747 }
748 }
749
750 self.action_log.update(cx, |log, cx| {
751 // Track all buffers added as context
752 for ctx in &new_context {
753 match ctx {
754 AssistantContext::File(file_ctx) => {
755 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
756 }
757 AssistantContext::Directory(dir_ctx) => {
758 for context_buffer in &dir_ctx.context_buffers {
759 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
760 }
761 }
762 AssistantContext::Symbol(symbol_ctx) => {
763 log.buffer_added_as_context(
764 symbol_ctx.context_symbol.buffer.clone(),
765 cx,
766 );
767 }
768 AssistantContext::Excerpt(excerpt_context) => {
769 log.buffer_added_as_context(
770 excerpt_context.context_buffer.buffer.clone(),
771 cx,
772 );
773 }
774 AssistantContext::FetchedUrl(_)
775 | AssistantContext::Thread(_)
776 | AssistantContext::Rules(_) => {}
777 }
778 }
779 });
780 }
781
782 let context_ids = new_context
783 .iter()
784 .map(|context| context.id())
785 .collect::<Vec<_>>();
786 self.context.extend(
787 new_context
788 .into_iter()
789 .map(|context| (context.id(), context)),
790 );
791 self.context_by_message.insert(message_id, context_ids);
792
793 if let Some(git_checkpoint) = git_checkpoint {
794 self.pending_checkpoint = Some(ThreadCheckpoint {
795 message_id,
796 git_checkpoint,
797 });
798 }
799
800 self.auto_capture_telemetry(cx);
801
802 message_id
803 }
804
805 pub fn insert_message(
806 &mut self,
807 role: Role,
808 segments: Vec<MessageSegment>,
809 cx: &mut Context<Self>,
810 ) -> MessageId {
811 let id = self.next_message_id.post_inc();
812 self.messages.push(Message {
813 id,
814 role,
815 segments,
816 context: String::new(),
817 });
818 self.touch_updated_at();
819 cx.emit(ThreadEvent::MessageAdded(id));
820 id
821 }
822
823 pub fn edit_message(
824 &mut self,
825 id: MessageId,
826 new_role: Role,
827 new_segments: Vec<MessageSegment>,
828 cx: &mut Context<Self>,
829 ) -> bool {
830 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
831 return false;
832 };
833 message.role = new_role;
834 message.segments = new_segments;
835 self.touch_updated_at();
836 cx.emit(ThreadEvent::MessageEdited(id));
837 true
838 }
839
840 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
841 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
842 return false;
843 };
844 self.messages.remove(index);
845 self.context_by_message.remove(&id);
846 self.touch_updated_at();
847 cx.emit(ThreadEvent::MessageDeleted(id));
848 true
849 }
850
851 /// Returns the representation of this [`Thread`] in a textual form.
852 ///
853 /// This is the representation we use when attaching a thread as context to another thread.
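    ///
    /// For example, a thread containing a single user message "Hi" renders roughly as
    /// `"User:\nHi\n"` (role header, newline, segment text, trailing newline).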
854 pub fn text(&self) -> String {
855 let mut text = String::new();
856
857 for message in &self.messages {
858 text.push_str(match message.role {
859 language_model::Role::User => "User:",
860 language_model::Role::Assistant => "Assistant:",
861 language_model::Role::System => "System:",
862 });
863 text.push('\n');
864
865 for segment in &message.segments {
866 match segment {
867 MessageSegment::Text(content) => text.push_str(content),
868 MessageSegment::Thinking { text: content, .. } => {
869 text.push_str(&format!("<think>{}</think>", content))
870 }
871 MessageSegment::RedactedThinking(_) => {}
872 }
873 }
874 text.push('\n');
875 }
876
877 text
878 }
879
880 /// Serializes this thread into a format for storage or telemetry.
881 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
882 let initial_project_snapshot = self.initial_project_snapshot.clone();
883 cx.spawn(async move |this, cx| {
884 let initial_project_snapshot = initial_project_snapshot.await;
885 this.read_with(cx, |this, cx| SerializedThread {
886 version: SerializedThread::VERSION.to_string(),
887 summary: this.summary_or_default(),
888 updated_at: this.updated_at(),
889 messages: this
890 .messages()
891 .map(|message| SerializedMessage {
892 id: message.id,
893 role: message.role,
894 segments: message
895 .segments
896 .iter()
897 .map(|segment| match segment {
898 MessageSegment::Text(text) => {
899 SerializedMessageSegment::Text { text: text.clone() }
900 }
901 MessageSegment::Thinking { text, signature } => {
902 SerializedMessageSegment::Thinking {
903 text: text.clone(),
904 signature: signature.clone(),
905 }
906 }
907 MessageSegment::RedactedThinking(data) => {
908 SerializedMessageSegment::RedactedThinking {
909 data: data.clone(),
910 }
911 }
912 })
913 .collect(),
914 tool_uses: this
915 .tool_uses_for_message(message.id, cx)
916 .into_iter()
917 .map(|tool_use| SerializedToolUse {
918 id: tool_use.id,
919 name: tool_use.name,
920 input: tool_use.input,
921 })
922 .collect(),
923 tool_results: this
924 .tool_results_for_message(message.id)
925 .into_iter()
926 .map(|tool_result| SerializedToolResult {
927 tool_use_id: tool_result.tool_use_id.clone(),
928 is_error: tool_result.is_error,
929 content: tool_result.content.clone(),
930 })
931 .collect(),
932 context: message.context.clone(),
933 })
934 .collect(),
935 initial_project_snapshot,
936 cumulative_token_usage: this.cumulative_token_usage,
937 request_token_usage: this.request_token_usage.clone(),
938 detailed_summary_state: this.detailed_summary_state.clone(),
939 exceeded_window_error: this.exceeded_window_error.clone(),
940 })
941 })
942 }
943
944 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
945 let mut request = self.to_completion_request(cx);
946 if model.supports_tools() {
947 request.tools = {
948 let mut tools = Vec::new();
949 tools.extend(
950 self.tools()
951 .read(cx)
952 .enabled_tools(cx)
953 .into_iter()
954 .filter_map(|tool| {
                        // Skip tools whose input schema can't be generated for this model's tool input format.
956 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
957 Some(LanguageModelRequestTool {
958 name: tool.name(),
959 description: tool.description(),
960 input_schema,
961 })
962 }),
963 );
964
965 tools
966 };
967 }
968
969 self.stream_completion(request, model, cx);
970 }
971
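    /// Returns whether any tool results have been recorded since the most recent
    /// user message, walking the messages from newest to oldest.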
972 pub fn used_tools_since_last_user_message(&self) -> bool {
973 for message in self.messages.iter().rev() {
974 if self.tool_use.message_has_tool_results(message.id) {
975 return true;
976 } else if message.role == Role::User {
977 return false;
978 }
979 }
980
981 false
982 }
983
984 pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
985 let mut request = LanguageModelRequest {
986 thread_id: Some(self.id.to_string()),
987 prompt_id: Some(self.last_prompt_id.to_string()),
988 messages: vec![],
989 tools: Vec::new(),
990 stop: Vec::new(),
991 temperature: None,
992 };
993
994 if let Some(project_context) = self.project_context.borrow().as_ref() {
995 match self
996 .prompt_builder
997 .generate_assistant_system_prompt(project_context)
998 {
999 Err(err) => {
1000 let message = format!("{err:?}").into();
1001 log::error!("{message}");
1002 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1003 header: "Error generating system prompt".into(),
1004 message,
1005 }));
1006 }
1007 Ok(system_prompt) => {
1008 request.messages.push(LanguageModelRequestMessage {
1009 role: Role::System,
1010 content: vec![MessageContent::Text(system_prompt)],
1011 cache: true,
1012 });
1013 }
1014 }
1015 } else {
1016 let message = "Context for system prompt unexpectedly not ready.".into();
1017 log::error!("{message}");
1018 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1019 header: "Error generating system prompt".into(),
1020 message,
1021 }));
1022 }
1023
1024 for message in &self.messages {
1025 let mut request_message = LanguageModelRequestMessage {
1026 role: message.role,
1027 content: Vec::new(),
1028 cache: false,
1029 };
1030
1031 self.tool_use
1032 .attach_tool_results(message.id, &mut request_message);
1033
1034 if !message.context.is_empty() {
1035 request_message
1036 .content
1037 .push(MessageContent::Text(message.context.to_string()));
1038 }
1039
1040 for segment in &message.segments {
1041 match segment {
1042 MessageSegment::Text(text) => {
1043 if !text.is_empty() {
1044 request_message
1045 .content
1046 .push(MessageContent::Text(text.into()));
1047 }
1048 }
1049 MessageSegment::Thinking { text, signature } => {
1050 if !text.is_empty() {
1051 request_message.content.push(MessageContent::Thinking {
1052 text: text.into(),
1053 signature: signature.clone(),
1054 });
1055 }
1056 }
1057 MessageSegment::RedactedThinking(data) => {
1058 request_message
1059 .content
1060 .push(MessageContent::RedactedThinking(data.clone()));
1061 }
1062 };
1063 }
1064
1065 self.tool_use
1066 .attach_tool_uses(message.id, &mut request_message);
1067
1068 request.messages.push(request_message);
1069 }
1070
1071 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1072 if let Some(last) = request.messages.last_mut() {
1073 last.cache = true;
1074 }
1075
1076 self.attached_tracked_files_state(&mut request.messages, cx);
1077
1078 request
1079 }
1080
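    /// Builds a request containing the text segments of the conversation so far
    /// (messages with tool results are skipped entirely, as are thinking segments)
    /// followed by `added_user_message`; used when asking the model for a summary.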
1081 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1082 let mut request = LanguageModelRequest {
1083 thread_id: None,
1084 prompt_id: None,
1085 messages: vec![],
1086 tools: Vec::new(),
1087 stop: Vec::new(),
1088 temperature: None,
1089 };
1090
1091 for message in &self.messages {
1092 let mut request_message = LanguageModelRequestMessage {
1093 role: message.role,
1094 content: Vec::new(),
1095 cache: false,
1096 };
1097
1098 // Skip tool results during summarization.
1099 if self.tool_use.message_has_tool_results(message.id) {
1100 continue;
1101 }
1102
1103 for segment in &message.segments {
1104 match segment {
1105 MessageSegment::Text(text) => request_message
1106 .content
1107 .push(MessageContent::Text(text.clone())),
1108 MessageSegment::Thinking { .. } => {}
1109 MessageSegment::RedactedThinking(_) => {}
1110 }
1111 }
1112
1113 if request_message.content.is_empty() {
1114 continue;
1115 }
1116
1117 request.messages.push(request_message);
1118 }
1119
1120 request.messages.push(LanguageModelRequestMessage {
1121 role: Role::User,
1122 content: vec![MessageContent::Text(added_user_message)],
1123 cache: false,
1124 });
1125
1126 request
1127 }
1128
1129 fn attached_tracked_files_state(
1130 &self,
1131 messages: &mut Vec<LanguageModelRequestMessage>,
1132 cx: &App,
1133 ) {
1134 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1135
1136 let mut stale_message = String::new();
1137
1138 let action_log = self.action_log.read(cx);
1139
1140 for stale_file in action_log.stale_buffers(cx) {
1141 let Some(file) = stale_file.read(cx).file() else {
1142 continue;
1143 };
1144
1145 if stale_message.is_empty() {
            writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1147 }
1148
1149 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1150 }
1151
1152 let mut content = Vec::with_capacity(2);
1153
1154 if !stale_message.is_empty() {
1155 content.push(stale_message.into());
1156 }
1157
1158 if !content.is_empty() {
1159 let context_message = LanguageModelRequestMessage {
1160 role: Role::User,
1161 content,
1162 cache: false,
1163 };
1164
1165 messages.push(context_message);
1166 }
1167 }
1168
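    /// Streams a completion for `request` from `model`, applying events to the
    /// thread as they arrive (text, thinking, tool uses, usage updates) and, once
    /// the stream ends, running any pending tools or reporting the error.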
1169 pub fn stream_completion(
1170 &mut self,
1171 request: LanguageModelRequest,
1172 model: Arc<dyn LanguageModel>,
1173 cx: &mut Context<Self>,
1174 ) {
1175 let pending_completion_id = post_inc(&mut self.completion_count);
1176 let mut request_callback_parameters = if self.request_callback.is_some() {
1177 Some((request.clone(), Vec::new()))
1178 } else {
1179 None
1180 };
1181 let prompt_id = self.last_prompt_id.clone();
1182 let tool_use_metadata = ToolUseMetadata {
1183 model: model.clone(),
1184 thread_id: self.id.clone(),
1185 prompt_id: prompt_id.clone(),
1186 };
1187
1188 let task = cx.spawn(async move |thread, cx| {
1189 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1190 let initial_token_usage =
1191 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1192 let stream_completion = async {
1193 let (mut events, usage) = stream_completion_future.await?;
1194
1195 let mut stop_reason = StopReason::EndTurn;
1196 let mut current_token_usage = TokenUsage::default();
1197
1198 if let Some(usage) = usage {
1199 thread
1200 .update(cx, |_thread, cx| {
1201 cx.emit(ThreadEvent::UsageUpdated(usage));
1202 })
1203 .ok();
1204 }
1205
1206 while let Some(event) = events.next().await {
1207 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1208 response_events
1209 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1210 }
1211
1212 let event = event?;
1213
1214 thread.update(cx, |thread, cx| {
1215 match event {
1216 LanguageModelCompletionEvent::StartMessage { .. } => {
1217 thread.insert_message(
1218 Role::Assistant,
1219 vec![MessageSegment::Text(String::new())],
1220 cx,
1221 );
1222 }
1223 LanguageModelCompletionEvent::Stop(reason) => {
1224 stop_reason = reason;
1225 }
1226 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1227 thread.update_token_usage_at_last_message(token_usage);
1228 thread.cumulative_token_usage = thread.cumulative_token_usage
1229 + token_usage
1230 - current_token_usage;
1231 current_token_usage = token_usage;
1232 }
1233 LanguageModelCompletionEvent::Text(chunk) => {
1234 cx.emit(ThreadEvent::ReceivedTextChunk);
1235 if let Some(last_message) = thread.messages.last_mut() {
1236 if last_message.role == Role::Assistant {
1237 last_message.push_text(&chunk);
1238 cx.emit(ThreadEvent::StreamedAssistantText(
1239 last_message.id,
1240 chunk,
1241 ));
1242 } else {
                                    // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                    // of a new Assistant response.
                                    //
                                    // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                    // will result in duplicating the text of the chunk in the rendered Markdown.
1248 thread.insert_message(
1249 Role::Assistant,
1250 vec![MessageSegment::Text(chunk.to_string())],
1251 cx,
1252 );
1253 };
1254 }
1255 }
1256 LanguageModelCompletionEvent::Thinking {
1257 text: chunk,
1258 signature,
1259 } => {
1260 if let Some(last_message) = thread.messages.last_mut() {
1261 if last_message.role == Role::Assistant {
1262 last_message.push_thinking(&chunk, signature);
1263 cx.emit(ThreadEvent::StreamedAssistantThinking(
1264 last_message.id,
1265 chunk,
1266 ));
1267 } else {
                                    // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                    // of a new Assistant response.
                                    //
                                    // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                    // will result in duplicating the text of the chunk in the rendered Markdown.
1273 thread.insert_message(
1274 Role::Assistant,
1275 vec![MessageSegment::Thinking {
1276 text: chunk.to_string(),
1277 signature,
1278 }],
1279 cx,
1280 );
1281 };
1282 }
1283 }
1284 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1285 let last_assistant_message_id = thread
1286 .messages
1287 .iter_mut()
1288 .rfind(|message| message.role == Role::Assistant)
1289 .map(|message| message.id)
1290 .unwrap_or_else(|| {
1291 thread.insert_message(Role::Assistant, vec![], cx)
1292 });
1293
1294 let tool_use_id = tool_use.id.clone();
1295 let streamed_input = if tool_use.is_input_complete {
1296 None
1297 } else {
                                Some(tool_use.input.clone())
1299 };
1300
1301 let ui_text = thread.tool_use.request_tool_use(
1302 last_assistant_message_id,
1303 tool_use,
1304 tool_use_metadata.clone(),
1305 cx,
1306 );
1307
1308 if let Some(input) = streamed_input {
1309 cx.emit(ThreadEvent::StreamedToolUse {
1310 tool_use_id,
1311 ui_text,
1312 input,
1313 });
1314 }
1315 }
1316 }
1317
1318 thread.touch_updated_at();
1319 cx.emit(ThreadEvent::StreamedCompletion);
1320 cx.notify();
1321
1322 thread.auto_capture_telemetry(cx);
1323 })?;
1324
1325 smol::future::yield_now().await;
1326 }
1327
1328 thread.update(cx, |thread, cx| {
1329 thread
1330 .pending_completions
1331 .retain(|completion| completion.id != pending_completion_id);
1332
1333 // If there is a response without tool use, summarize the message. Otherwise,
1334 // allow two tool uses before summarizing.
1335 if thread.summary.is_none()
1336 && thread.messages.len() >= 2
1337 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1338 {
1339 thread.summarize(cx);
1340 }
1341 })?;
1342
1343 anyhow::Ok(stop_reason)
1344 };
1345
1346 let result = stream_completion.await;
1347
1348 thread
1349 .update(cx, |thread, cx| {
1350 thread.finalize_pending_checkpoint(cx);
1351 match result.as_ref() {
1352 Ok(stop_reason) => match stop_reason {
1353 StopReason::ToolUse => {
1354 let tool_uses = thread.use_pending_tools(cx);
1355 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1356 }
1357 StopReason::EndTurn => {}
1358 StopReason::MaxTokens => {}
1359 },
1360 Err(error) => {
1361 if error.is::<PaymentRequiredError>() {
1362 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1363 } else if error.is::<MaxMonthlySpendReachedError>() {
1364 cx.emit(ThreadEvent::ShowError(
1365 ThreadError::MaxMonthlySpendReached,
1366 ));
1367 } else if let Some(error) =
1368 error.downcast_ref::<ModelRequestLimitReachedError>()
1369 {
1370 cx.emit(ThreadEvent::ShowError(
1371 ThreadError::ModelRequestLimitReached { plan: error.plan },
1372 ));
1373 } else if let Some(known_error) =
1374 error.downcast_ref::<LanguageModelKnownError>()
1375 {
1376 match known_error {
1377 LanguageModelKnownError::ContextWindowLimitExceeded {
1378 tokens,
1379 } => {
1380 thread.exceeded_window_error = Some(ExceededWindowError {
1381 model_id: model.id(),
1382 token_count: *tokens,
1383 });
1384 cx.notify();
1385 }
1386 }
1387 } else {
1388 let error_message = error
1389 .chain()
1390 .map(|err| err.to_string())
1391 .collect::<Vec<_>>()
1392 .join("\n");
1393 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1394 header: "Error interacting with language model".into(),
1395 message: SharedString::from(error_message.clone()),
1396 }));
1397 }
1398
1399 thread.cancel_last_completion(cx);
1400 }
1401 }
1402 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1403
1404 if let Some((request_callback, (request, response_events))) = thread
1405 .request_callback
1406 .as_mut()
1407 .zip(request_callback_parameters.as_ref())
1408 {
1409 request_callback(request, response_events);
1410 }
1411
1412 thread.auto_capture_telemetry(cx);
1413
1414 if let Ok(initial_usage) = initial_token_usage {
1415 let usage = thread.cumulative_token_usage - initial_usage;
1416
1417 telemetry::event!(
1418 "Assistant Thread Completion",
1419 thread_id = thread.id().to_string(),
1420 prompt_id = prompt_id,
1421 model = model.telemetry_id(),
1422 model_provider = model.provider_id().to_string(),
1423 input_tokens = usage.input_tokens,
1424 output_tokens = usage.output_tokens,
1425 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1426 cache_read_input_tokens = usage.cache_read_input_tokens,
1427 );
1428 }
1429 })
1430 .ok();
1431 });
1432
1433 self.pending_completions.push(PendingCompletion {
1434 id: pending_completion_id,
1435 _task: task,
1436 });
1437 }
1438
1439 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1440 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1441 return;
1442 };
1443
1444 if !model.provider.is_authenticated(cx) {
1445 return;
1446 }
1447
1448 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1449 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1450 If the conversation is about a specific subject, include it in the title. \
1451 Be descriptive. DO NOT speak in the first person.";
1452
1453 let request = self.to_summarize_request(added_user_message.into());
1454
1455 self.pending_summary = cx.spawn(async move |this, cx| {
1456 async move {
1457 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1458 let (mut messages, usage) = stream.await?;
1459
1460 if let Some(usage) = usage {
1461 this.update(cx, |_thread, cx| {
1462 cx.emit(ThreadEvent::UsageUpdated(usage));
1463 })
1464 .ok();
1465 }
1466
1467 let mut new_summary = String::new();
1468 while let Some(message) = messages.stream.next().await {
1469 let text = message?;
1470 let mut lines = text.lines();
1471 new_summary.extend(lines.next());
1472
1473 // Stop if the LLM generated multiple lines.
1474 if lines.next().is_some() {
1475 break;
1476 }
1477 }
1478
1479 this.update(cx, |this, cx| {
1480 if !new_summary.is_empty() {
1481 this.summary = Some(new_summary.into());
1482 }
1483
1484 cx.emit(ThreadEvent::SummaryGenerated);
1485 })?;
1486
1487 anyhow::Ok(())
1488 }
1489 .log_err()
1490 .await
1491 });
1492 }
1493
1494 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1495 let last_message_id = self.messages.last().map(|message| message.id)?;
1496
1497 match &self.detailed_summary_state {
1498 DetailedSummaryState::Generating { message_id, .. }
1499 | DetailedSummaryState::Generated { message_id, .. }
1500 if *message_id == last_message_id =>
1501 {
1502 // Already up-to-date
1503 return None;
1504 }
1505 _ => {}
1506 }
1507
1508 let ConfiguredModel { model, provider } =
1509 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1510
1511 if !provider.is_authenticated(cx) {
1512 return None;
1513 }
1514
1515 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1516 1. A brief overview of what was discussed\n\
1517 2. Key facts or information discovered\n\
1518 3. Outcomes or conclusions reached\n\
1519 4. Any action items or next steps if any\n\
1520 Format it in Markdown with headings and bullet points.";
1521
1522 let request = self.to_summarize_request(added_user_message.into());
1523
1524 let task = cx.spawn(async move |thread, cx| {
1525 let stream = model.stream_completion_text(request, &cx);
1526 let Some(mut messages) = stream.await.log_err() else {
1527 thread
1528 .update(cx, |this, _cx| {
1529 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1530 })
1531 .log_err();
1532
1533 return;
1534 };
1535
1536 let mut new_detailed_summary = String::new();
1537
1538 while let Some(chunk) = messages.stream.next().await {
1539 if let Some(chunk) = chunk.log_err() {
1540 new_detailed_summary.push_str(&chunk);
1541 }
1542 }
1543
1544 thread
1545 .update(cx, |this, _cx| {
1546 this.detailed_summary_state = DetailedSummaryState::Generated {
1547 text: new_detailed_summary.into(),
1548 message_id: last_message_id,
1549 };
1550 })
1551 .log_err();
1552 });
1553
1554 self.detailed_summary_state = DetailedSummaryState::Generating {
1555 message_id: last_message_id,
1556 };
1557
1558 Some(task)
1559 }
1560
1561 pub fn is_generating_detailed_summary(&self) -> bool {
1562 matches!(
1563 self.detailed_summary_state,
1564 DetailedSummaryState::Generating { .. }
1565 )
1566 }
1567
1568 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1569 self.auto_capture_telemetry(cx);
1570 let request = self.to_completion_request(cx);
1571 let messages = Arc::new(request.messages);
1572 let pending_tool_uses = self
1573 .tool_use
1574 .pending_tool_uses()
1575 .into_iter()
1576 .filter(|tool_use| tool_use.status.is_idle())
1577 .cloned()
1578 .collect::<Vec<_>>();
1579
1580 for tool_use in pending_tool_uses.iter() {
1581 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1582 if tool.needs_confirmation(&tool_use.input, cx)
1583 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1584 {
1585 self.tool_use.confirm_tool_use(
1586 tool_use.id.clone(),
1587 tool_use.ui_text.clone(),
1588 tool_use.input.clone(),
1589 messages.clone(),
1590 tool,
1591 );
1592 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1593 } else {
1594 self.run_tool(
1595 tool_use.id.clone(),
1596 tool_use.ui_text.clone(),
1597 tool_use.input.clone(),
1598 &messages,
1599 tool,
1600 cx,
1601 );
1602 }
1603 }
1604 }
1605
1606 pending_tool_uses
1607 }
1608
1609 pub fn run_tool(
1610 &mut self,
1611 tool_use_id: LanguageModelToolUseId,
1612 ui_text: impl Into<SharedString>,
1613 input: serde_json::Value,
1614 messages: &[LanguageModelRequestMessage],
1615 tool: Arc<dyn Tool>,
1616 cx: &mut Context<Thread>,
1617 ) {
1618 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1619 self.tool_use
1620 .run_pending_tool(tool_use_id, ui_text.into(), task);
1621 }
1622
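    /// Runs `tool` with `input` (unless the tool is disabled in the working set) and
    /// records its output against `tool_use_id` once it completes.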
1623 fn spawn_tool_use(
1624 &mut self,
1625 tool_use_id: LanguageModelToolUseId,
1626 messages: &[LanguageModelRequestMessage],
1627 input: serde_json::Value,
1628 tool: Arc<dyn Tool>,
1629 cx: &mut Context<Thread>,
1630 ) -> Task<()> {
1631 let tool_name: Arc<str> = tool.name().into();
1632
1633 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1634 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1635 } else {
1636 tool.run(
1637 input,
1638 messages,
1639 self.project.clone(),
1640 self.action_log.clone(),
1641 cx,
1642 )
1643 };
1644
1645 // Store the card separately if it exists
1646 if let Some(card) = tool_result.card.clone() {
1647 self.tool_use
1648 .insert_tool_result_card(tool_use_id.clone(), card);
1649 }
1650
1651 cx.spawn({
1652 async move |thread: WeakEntity<Thread>, cx| {
1653 let output = tool_result.output.await;
1654
1655 thread
1656 .update(cx, |thread, cx| {
1657 let pending_tool_use = thread.tool_use.insert_tool_output(
1658 tool_use_id.clone(),
1659 tool_name,
1660 output,
1661 cx,
1662 );
1663 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1664 })
1665 .ok();
1666 }
1667 })
1668 }
1669
1670 fn tool_finished(
1671 &mut self,
1672 tool_use_id: LanguageModelToolUseId,
1673 pending_tool_use: Option<PendingToolUse>,
1674 canceled: bool,
1675 cx: &mut Context<Self>,
1676 ) {
1677 if self.all_tools_finished() {
1678 let model_registry = LanguageModelRegistry::read_global(cx);
1679 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1680 self.attach_tool_results(cx);
1681 if !canceled {
1682 self.send_to_model(model, cx);
1683 }
1684 }
1685 }
1686
1687 cx.emit(ThreadEvent::ToolFinished {
1688 tool_use_id,
1689 pending_tool_use,
1690 });
1691 }
1692
1693 /// Insert an empty message to be populated with tool results upon send.
1694 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message ID, so they will populate
        // this empty message before it is sent to the model. Would prefer this to be more straightforward.
1697 self.insert_message(Role::User, vec![], cx);
1698 self.auto_capture_telemetry(cx);
1699 }
1700
1701 /// Cancels the last pending completion, if there are any pending.
1702 ///
1703 /// Returns whether a completion was canceled.
1704 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1705 let canceled = if self.pending_completions.pop().is_some() {
1706 true
1707 } else {
1708 let mut canceled = false;
1709 for pending_tool_use in self.tool_use.cancel_pending() {
1710 canceled = true;
1711 self.tool_finished(
1712 pending_tool_use.id.clone(),
1713 Some(pending_tool_use),
1714 true,
1715 cx,
1716 );
1717 }
1718 canceled
1719 };
1720 self.finalize_pending_checkpoint(cx);
1721 canceled
1722 }
1723
1724 pub fn feedback(&self) -> Option<ThreadFeedback> {
1725 self.feedback
1726 }
1727
1728 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1729 self.message_feedback.get(&message_id).copied()
1730 }
1731
1732 pub fn report_message_feedback(
1733 &mut self,
1734 message_id: MessageId,
1735 feedback: ThreadFeedback,
1736 cx: &mut Context<Self>,
1737 ) -> Task<Result<()>> {
1738 if self.message_feedback.get(&message_id) == Some(&feedback) {
1739 return Task::ready(Ok(()));
1740 }
1741
1742 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1743 let serialized_thread = self.serialize(cx);
1744 let thread_id = self.id().clone();
1745 let client = self.project.read(cx).client();
1746
1747 let enabled_tool_names: Vec<String> = self
1748 .tools()
1749 .read(cx)
1750 .enabled_tools(cx)
1751 .iter()
1752 .map(|tool| tool.name().to_string())
1753 .collect();
1754
1755 self.message_feedback.insert(message_id, feedback);
1756
1757 cx.notify();
1758
1759 let message_content = self
1760 .message(message_id)
1761 .map(|msg| msg.to_string())
1762 .unwrap_or_default();
1763
1764 cx.background_spawn(async move {
1765 let final_project_snapshot = final_project_snapshot.await;
1766 let serialized_thread = serialized_thread.await?;
1767 let thread_data =
1768 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1769
1770 let rating = match feedback {
1771 ThreadFeedback::Positive => "positive",
1772 ThreadFeedback::Negative => "negative",
1773 };
1774 telemetry::event!(
1775 "Assistant Thread Rated",
1776 rating,
1777 thread_id,
1778 enabled_tool_names,
1779 message_id = message_id.0,
1780 message_content,
1781 thread_data,
1782 final_project_snapshot
1783 );
1784 client.telemetry().flush_events().await;
1785
1786 Ok(())
1787 })
1788 }
1789
1790 pub fn report_feedback(
1791 &mut self,
1792 feedback: ThreadFeedback,
1793 cx: &mut Context<Self>,
1794 ) -> Task<Result<()>> {
1795 let last_assistant_message_id = self
1796 .messages
1797 .iter()
1798 .rev()
1799 .find(|msg| msg.role == Role::Assistant)
1800 .map(|msg| msg.id);
1801
1802 if let Some(message_id) = last_assistant_message_id {
1803 self.report_message_feedback(message_id, feedback, cx)
1804 } else {
1805 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1806 let serialized_thread = self.serialize(cx);
1807 let thread_id = self.id().clone();
1808 let client = self.project.read(cx).client();
1809 self.feedback = Some(feedback);
1810 cx.notify();
1811
1812 cx.background_spawn(async move {
1813 let final_project_snapshot = final_project_snapshot.await;
1814 let serialized_thread = serialized_thread.await?;
1815 let thread_data = serde_json::to_value(serialized_thread)
1816 .unwrap_or_else(|_| serde_json::Value::Null);
1817
1818 let rating = match feedback {
1819 ThreadFeedback::Positive => "positive",
1820 ThreadFeedback::Negative => "negative",
1821 };
1822 telemetry::event!(
1823 "Assistant Thread Rated",
1824 rating,
1825 thread_id,
1826 thread_data,
1827 final_project_snapshot
1828 );
1829 client.telemetry().flush_events().await;
1830
1831 Ok(())
1832 })
1833 }
1834 }
1835
1836 /// Create a snapshot of the current project state including git information and unsaved buffers.
1837 fn project_snapshot(
1838 project: Entity<Project>,
1839 cx: &mut Context<Self>,
1840 ) -> Task<Arc<ProjectSnapshot>> {
1841 let git_store = project.read(cx).git_store().clone();
1842 let worktree_snapshots: Vec<_> = project
1843 .read(cx)
1844 .visible_worktrees(cx)
1845 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1846 .collect();
1847
1848 cx.spawn(async move |_, cx| {
1849 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1850
1851 let mut unsaved_buffers = Vec::new();
1852 cx.update(|app_cx| {
1853 let buffer_store = project.read(app_cx).buffer_store();
1854 for buffer_handle in buffer_store.read(app_cx).buffers() {
1855 let buffer = buffer_handle.read(app_cx);
1856 if buffer.is_dirty() {
1857 if let Some(file) = buffer.file() {
1858 let path = file.path().to_string_lossy().to_string();
1859 unsaved_buffers.push(path);
1860 }
1861 }
1862 }
1863 })
1864 .ok();
1865
1866 Arc::new(ProjectSnapshot {
1867 worktree_snapshots,
1868 unsaved_buffer_paths: unsaved_buffers,
1869 timestamp: Utc::now(),
1870 })
1871 })
1872 }
1873
1874 fn worktree_snapshot(
1875 worktree: Entity<project::Worktree>,
1876 git_store: Entity<GitStore>,
1877 cx: &App,
1878 ) -> Task<WorktreeSnapshot> {
1879 cx.spawn(async move |cx| {
1880 // Get worktree path and snapshot
1881 let worktree_info = cx.update(|app_cx| {
1882 let worktree = worktree.read(app_cx);
1883 let path = worktree.abs_path().to_string_lossy().to_string();
1884 let snapshot = worktree.snapshot();
1885 (path, snapshot)
1886 });
1887
1888 let Ok((worktree_path, _snapshot)) = worktree_info else {
1889 return WorktreeSnapshot {
1890 worktree_path: String::new(),
1891 git_state: None,
1892 };
1893 };
1894
1895 let git_state = git_store
1896 .update(cx, |git_store, cx| {
1897 git_store
1898 .repositories()
1899 .values()
1900 .find(|repo| {
1901 repo.read(cx)
1902 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1903 .is_some()
1904 })
1905 .cloned()
1906 })
1907 .ok()
1908 .flatten()
1909 .map(|repo| {
1910 repo.update(cx, |repo, _| {
1911 let current_branch =
1912 repo.branch.as_ref().map(|branch| branch.name.to_string());
1913 repo.send_job(None, |state, _| async move {
1914 let RepositoryState::Local { backend, .. } = state else {
1915 return GitState {
1916 remote_url: None,
1917 head_sha: None,
1918 current_branch,
1919 diff: None,
1920 };
1921 };
1922
1923 let remote_url = backend.remote_url("origin");
1924 let head_sha = backend.head_sha();
1925 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1926
1927 GitState {
1928 remote_url,
1929 head_sha,
1930 current_branch,
1931 diff,
1932 }
1933 })
1934 })
1935 });
1936
1937 let git_state = match git_state {
1938 Some(git_state) => match git_state.ok() {
1939 Some(git_state) => git_state.await.ok(),
1940 None => None,
1941 },
1942 None => None,
1943 };
1944
1945 WorktreeSnapshot {
1946 worktree_path,
1947 git_state,
1948 }
1949 })
1950 }
1951
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        if let Some(summary) = self.summary() {
            writeln!(markdown, "# {summary}\n")?;
        }

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Assistant",
                    Role::System => "System",
                }
            )?;

            if !message.context.is_empty() {
                writeln!(markdown, "{}", message.context)?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                writeln!(markdown, "{}", tool_result.content)?;
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }

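    /// Marks the agent's edits that intersect `buffer_range` in `buffer` as kept,
    /// delegating to the [`ActionLog`].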
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

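    /// Marks all of the agent's pending edits as kept.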
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

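    /// Rejects the agent's edits within the given buffer ranges, delegating to the
    /// [`ActionLog`].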
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }

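    /// Returns the [`ActionLog`] that tracks the agent's edits for this thread.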
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

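    /// Returns the [`Project`] this thread operates on.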
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

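    /// Captures the serialized thread as a telemetry event for users with the
    /// `ThreadAutoCapture` feature flag enabled, throttled to at most once every
    /// ten seconds.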
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
            return;
        }

        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client().clone();
        let serialize_task = self.serialize(cx);

        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }

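    /// Returns the token usage accumulated across all requests made by this thread.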
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }

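    /// Returns the token usage recorded as of the message preceding `message_id`
    /// (zero for the first message), along with the default model's maximum token count.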
    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        let index = self
            .messages
            .iter()
            .position(|msg| msg.id == message_id)
            .unwrap_or(0);

        if index == 0 {
            return TotalTokenUsage { total: 0, max };
        }

        let token_usage = self
            .request_token_usage
            .get(index - 1)
            .cloned()
            .unwrap_or_default();

        TotalTokenUsage {
            total: token_usage.total_tokens() as usize,
            max,
        }
    }

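    /// Returns the thread's current token usage against the default model's context
    /// window. If the context window was previously exceeded for this model, the token
    /// count captured by that error is reported instead.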
    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
        let model_registry = LanguageModelRegistry::read_global(cx);
        let Some(model) = model_registry.default_model() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        if let Some(exceeded_error) = &self.exceeded_window_error {
            if model.model.id() == exceeded_error.model_id {
                return TotalTokenUsage {
                    total: exceeded_error.token_count,
                    max,
                };
            }
        }

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens() as usize;

        TotalTokenUsage { total, max }
    }

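    /// Returns the token usage recorded for the most recent message, if any.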
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }

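    /// Records `token_usage` against the most recent message, growing the per-message
    /// usage history to match the current number of messages if needed.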
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }

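    /// Rejects a pending tool use on the user's behalf, recording a permission-denied
    /// error as the tool's output and marking the tool as finished.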
    pub fn deny_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        cx: &mut Context<Self>,
    ) {
        let err = Err(anyhow::anyhow!(
            "Permission to run tool action denied by user"
        ));

        self.tool_use
            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
        self.tool_finished(tool_use_id.clone(), None, true, cx);
    }
}

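/// Errors reported while running a thread, surfaced to the UI via
/// [`ThreadEvent::ShowError`].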
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}

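/// Events emitted by a [`Thread`] as completions stream in, messages change, and
/// tools are run.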
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    UsageUpdated(RequestUsage),
    StreamedCompletion,
    ReceivedTextChunk,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}

impl EventEmitter<ThreadEvent> for Thread {}

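/// An in-flight completion request; `_task` holds the task driving the stream so it
/// keeps running while this completion is pending.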
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain the system prompt plus all 3 user messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text near the start of the buffer so it differs from the last-read version
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

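    // Register the global settings and stores the thread tests depend on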
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

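    // Helper to build a workspace, thread store, thread, and context store for tests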
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

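    // Helper to open a buffer for `path` and add it to the context store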
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}