1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
42
/// The kind of completion request being sent to the model.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A normal conversational turn, including tool uses and results.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
49
/// A unique identifier for a [`Thread`].
///
/// Backed by a UUID string; `Arc<str>` makes clones cheap.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Stored as a cheaply-clonable UUID string.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
78
79impl PromptId {
80 pub fn new() -> Self {
81 Self(Uuid::new_v4().to_string().into())
82 }
83}
84
85impl std::fmt::Display for PromptId {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", self.0)
88 }
89}
90
/// A monotonically-increasing identifier for a [`Message`] within a single thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
93
94impl MessageId {
95 fn post_inc(&mut self) -> Self {
96 Self(post_inc(&mut self.0))
97 }
98}
99
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments that make up the message body.
    pub segments: Vec<MessageSegment>,
    /// Pre-formatted context (files, directories, symbols, …) attached to this message.
    pub context: String,
}
108
109impl Message {
110 /// Returns whether the message contains any meaningful text that should be displayed
111 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
112 pub fn should_display_content(&self) -> bool {
113 self.segments.iter().all(|segment| segment.should_display())
114 }
115
116 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
117 if let Some(MessageSegment::Thinking {
118 text: segment,
119 signature: current_signature,
120 }) = self.segments.last_mut()
121 {
122 if let Some(signature) = signature {
123 *current_signature = Some(signature);
124 }
125 segment.push_str(text);
126 } else {
127 self.segments.push(MessageSegment::Thinking {
128 text: text.to_string(),
129 signature,
130 });
131 }
132 }
133
134 pub fn push_text(&mut self, text: &str) {
135 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Text(text.to_string()));
139 }
140 }
141
142 pub fn to_string(&self) -> String {
143 let mut result = String::new();
144
145 if !self.context.is_empty() {
146 result.push_str(&self.context);
147 }
148
149 for segment in &self.segments {
150 match segment {
151 MessageSegment::Text(text) => result.push_str(text),
152 MessageSegment::Thinking { text, .. } => {
153 result.push_str("<think>\n");
154 result.push_str(text);
155 result.push_str("\n</think>");
156 }
157 MessageSegment::RedactedThinking(_) => {}
158 }
159 }
160
161 result
162 }
163}
164
/// One contiguous piece of a [`Message`] body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary visible text.
    Text(String),
    /// Model "thinking" output, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking; opaque bytes with no displayable text.
    RedactedThinking(Vec<u8>),
}
174
175impl MessageSegment {
176 pub fn should_display(&self) -> bool {
177 // We add USING_TOOL_MARKER when making a request that includes tool uses
178 // without non-whitespace text around them, and this can cause the model
179 // to mimic the pattern, so we consider those segments not displayable.
180 match self {
181 Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
182 Self::Thinking { text, .. } => text.is_empty() || text.trim() == USING_TOOL_MARKER,
183 Self::RedactedThinking(_) => false,
184 }
185 }
186}
187
/// A snapshot of the project's state, captured for storage/telemetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
194
/// Per-worktree portion of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree is not inside a git repository.
    pub git_state: Option<GitState>,
}
200
/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text, if one could be computed.
    pub diff: Option<String>,
}
208
/// Associates a git checkpoint with the message that triggered it,
/// allowing the project to be rolled back to that point in the thread.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
214
/// User feedback (thumbs up/down) on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
220
/// Status of the most recent checkpoint restore.
pub enum LastRestoreCheckpoint {
    /// The restore is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed; `error` holds a user-displayable description.
    Error {
        message_id: MessageId,
        error: String,
    },
}
230
231impl LastRestoreCheckpoint {
232 pub fn message_id(&self) -> MessageId {
233 match self {
234 LastRestoreCheckpoint::Pending { message_id } => *message_id,
235 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
236 }
237 }
238}
239
/// Progress of the (lazily generated) detailed thread summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// A summary covering messages up to `message_id` is being generated.
    Generating {
        message_id: MessageId,
    },
    /// `text` summarizes the thread up to and including `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
252
/// Token usage relative to the model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// The model's context-window limit.
    pub max: usize,
}
258
259impl TotalTokenUsage {
260 pub fn ratio(&self) -> TokenUsageRatio {
261 #[cfg(debug_assertions)]
262 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
263 .unwrap_or("0.8".to_string())
264 .parse()
265 .unwrap();
266 #[cfg(not(debug_assertions))]
267 let warning_threshold: f32 = 0.8;
268
269 if self.total >= self.max {
270 TokenUsageRatio::Exceeded
271 } else if self.total as f32 / self.max as f32 >= warning_threshold {
272 TokenUsageRatio::Warning
273 } else {
274 TokenUsageRatio::Normal
275 }
276 }
277
278 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
279 TotalTokenUsage {
280 total: self.total + tokens,
281 max: self.max,
282 }
283 }
284}
285
/// Coarse classification produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    /// Usage crossed the warning threshold but is still under the limit.
    Warning,
    /// Usage reached or exceeded the context-window limit.
    Exceeded,
}
293
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// One-line summary; `None` until first generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    /// All context ever attached to this thread, keyed by context ID.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Which context IDs were attached alongside each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per user message, used for restores.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken when the current user message was sent; finalized
    /// or discarded once generation and all tools complete.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// Rate-limits automatic telemetry capture.
    last_auto_capture_at: Option<Instant>,
    /// Optional hook observing each request and its stringified response events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
}
328
/// Records that a request overflowed the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
336
337impl Thread {
    /// Creates a new, empty thread for `project`, wired to the given tool
    /// working set and prompt builder. Kicks off a background capture of the
    /// initial project snapshot for later serialization/telemetry.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Snapshot asynchronously so thread creation isn't blocked on
                // worktree/git inspection; `.shared()` lets serialize() await it later.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
382
    /// Reconstructs a thread from its serialized form (see [`SerializedThread`]).
    ///
    /// Message IDs, token usage, summaries, and tool use are restored; transient
    /// state (pending completions, checkpoints, feedback) starts fresh.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest restored message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            // The snapshot was captured when the thread was first created;
            // reuse it rather than re-snapshotting the project.
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
455
    /// Installs a callback invoked with each outgoing completion request and
    /// its stringified response events, e.g. for debugging or evals.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
463
    /// This thread's unique ID.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
467
    /// Whether the thread contains no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
471
    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
475
    /// Bumps the last-modified timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
479
    /// Assigns a fresh prompt ID; called when the user submits a new prompt.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
483
    /// The generated summary, or `None` if not yet generated.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
487
    /// The shared project context used to build the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
491
    /// Placeholder summary used before one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The generated summary, falling back to [`Self::DEFAULT_SUMMARY`].
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
497
498 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
499 let Some(current_summary) = &self.summary else {
500 // Don't allow setting summary until generated
501 return;
502 };
503
504 let mut new_summary = new_summary.into();
505
506 if new_summary.is_empty() {
507 new_summary = Self::DEFAULT_SUMMARY;
508 }
509
510 if current_summary != &new_summary {
511 self.summary = Some(new_summary);
512 cx.emit(ThreadEvent::SummaryChanged);
513 }
514 }
515
    /// Returns the detailed summary if generated, otherwise the full thread text.
    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.latest_detailed_summary()
            .unwrap_or_else(|| self.text().into())
    }
520
521 fn latest_detailed_summary(&self) -> Option<SharedString> {
522 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
523 Some(text.clone())
524 } else {
525 None
526 }
527 }
528
    /// Looks up a message by ID.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }
532
    /// Iterates over all messages in insertion order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
536
    /// Whether a completion is streaming or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
540
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
544
545 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
546 self.tool_use
547 .pending_tool_uses()
548 .into_iter()
549 .find(|tool_use| &tool_use.id == id)
550 }
551
    /// Iterates over pending tool uses that are awaiting user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
558
    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
562
    /// Returns the checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
566
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Emits [`ThreadEvent::CheckpointChanged`] both when the restore starts
    /// (marking it pending) and when it resolves (success or error).
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Record the in-flight restore so the UI can show progress.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around for display; the thread is untouched.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any in-progress checkpoint no longer applies after a restore attempt.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
602
    /// Resolves the checkpoint taken at the start of the current turn once
    /// generation and all tools have finished.
    ///
    /// Compares the pending checkpoint against the project's current state:
    /// if nothing changed, the checkpoint is discarded; otherwise it is kept
    /// so the user can roll back to the pre-turn state.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still generating — finalize later, when the turn completes.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A comparison failure is treated as "changed" so we err on
                // the side of keeping the checkpoint.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    // No changes this turn — the checkpoint is redundant.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The comparison checkpoint is only needed transiently.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // Couldn't snapshot the current state — keep the pending
            // checkpoint rather than silently losing the rollback point.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
653
    /// Records a checkpoint under its message ID and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
660
    /// The status of the most recent checkpoint restore, if one occurred.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
664
665 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
666 let Some(message_ix) = self
667 .messages
668 .iter()
669 .rposition(|message| message.id == message_id)
670 else {
671 return;
672 };
673 for deleted_message in self.messages.drain(message_ix..) {
674 self.context_by_message.remove(&deleted_message.id);
675 self.checkpoints_by_message.remove(&deleted_message.id);
676 }
677 cx.notify();
678 }
679
    /// Iterates over the context entries attached to the given message.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
        self.context_by_message
            .get(&id)
            .into_iter()
            .flat_map(|context| {
                context
                    .iter()
                    .filter_map(|context_id| self.context.get(&context_id))
            })
    }
690
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        // (Errored uses stay "pending" until their results are surfaced.)
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
700
    /// All tool uses issued by the given (assistant) message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
704
    /// All tool results associated with the given message.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }
708
    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
712
713 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
714 Some(&self.tool_use.tool_result(id)?.content)
715 }
716
    /// The UI card rendered for a tool's result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
720
    /// Whether any tool results are recorded against the given message.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
724
    /// Filter out contexts that have already been included in previous messages
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }
732
    /// Whether the given context has not yet been attached to this thread.
    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }
736
    /// Appends a user message with its attached context, deduplicating context
    /// already present in the thread, registering context buffers with the
    /// action log, and optionally recording a pending git checkpoint.
    ///
    /// Returns the new message's ID.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Only context not already attached earlier in the thread is included.
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::Excerpt(excerpt_context) => {
                            log.buffer_added_as_context(
                                excerpt_context.context_buffer.buffer.clone(),
                                cx,
                            );
                        }
                        // These context kinds carry no buffers to track.
                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
                    }
                }
            });
        }

        // Record which context arrived with this message, and index the
        // context itself for future deduplication.
        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            // Finalized (kept or discarded) when the turn completes.
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
812
813 pub fn insert_message(
814 &mut self,
815 role: Role,
816 segments: Vec<MessageSegment>,
817 cx: &mut Context<Self>,
818 ) -> MessageId {
819 let id = self.next_message_id.post_inc();
820 self.messages.push(Message {
821 id,
822 role,
823 segments,
824 context: String::new(),
825 });
826 self.touch_updated_at();
827 cx.emit(ThreadEvent::MessageAdded(id));
828 id
829 }
830
831 pub fn edit_message(
832 &mut self,
833 id: MessageId,
834 new_role: Role,
835 new_segments: Vec<MessageSegment>,
836 cx: &mut Context<Self>,
837 ) -> bool {
838 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
839 return false;
840 };
841 message.role = new_role;
842 message.segments = new_segments;
843 self.touch_updated_at();
844 cx.emit(ThreadEvent::MessageEdited(id));
845 true
846 }
847
848 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
849 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
850 return false;
851 };
852 self.messages.remove(index);
853 self.context_by_message.remove(&id);
854 self.touch_updated_at();
855 cx.emit(ThreadEvent::MessageDeleted(id));
856 true
857 }
858
859 /// Returns the representation of this [`Thread`] in a textual form.
860 ///
861 /// This is the representation we use when attaching a thread as context to another thread.
862 pub fn text(&self) -> String {
863 let mut text = String::new();
864
865 for message in &self.messages {
866 text.push_str(match message.role {
867 language_model::Role::User => "User:",
868 language_model::Role::Assistant => "Assistant:",
869 language_model::Role::System => "System:",
870 });
871 text.push('\n');
872
873 for segment in &message.segments {
874 match segment {
875 MessageSegment::Text(content) => text.push_str(content),
876 MessageSegment::Thinking { text: content, .. } => {
877 text.push_str(&format!("<think>{}</think>", content))
878 }
879 MessageSegment::RedactedThinking(_) => {}
880 }
881 }
882 text.push('\n');
883 }
884
885 text
886 }
887
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot (captured at thread creation)
    /// before producing the [`SerializedThread`].
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results are persisted per message so the
                        // deserialized thread can rebuild its ToolUseState.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
951
952 pub fn send_to_model(
953 &mut self,
954 model: Arc<dyn LanguageModel>,
955 request_kind: RequestKind,
956 cx: &mut Context<Self>,
957 ) {
958 let mut request = self.to_completion_request(request_kind, cx);
959 if model.supports_tools() {
960 request.tools = {
961 let mut tools = Vec::new();
962 tools.extend(
963 self.tools()
964 .read(cx)
965 .enabled_tools(cx)
966 .into_iter()
967 .filter_map(|tool| {
968 // Skip tools that cannot be supported
969 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
970 Some(LanguageModelRequestTool {
971 name: tool.name(),
972 description: tool.description(),
973 input_schema,
974 })
975 }),
976 );
977
978 tools
979 };
980 }
981
982 self.stream_completion(request, model, cx);
983 }
984
985 pub fn used_tools_since_last_user_message(&self) -> bool {
986 for message in self.messages.iter().rev() {
987 if self.tool_use.message_has_tool_results(message.id) {
988 return true;
989 } else if message.role == Role::User {
990 return false;
991 }
992 }
993
994 false
995 }
996
    /// Assembles the full [`LanguageModelRequest`] for this thread: system
    /// prompt, every message (with its context, segments, and — for
    /// [`RequestKind::Chat`] — tool uses/results), plus stale-file notices.
    pub fn to_completion_request(
        &self,
        request_kind: RequestKind,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        // The system prompt is stable across turns, so mark it
                        // cacheable for providers that support prompt caching.
                        cache: true,
                    });
                }
            }
        } else {
            // Proceed without a system prompt rather than failing the request.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            match request_kind {
                RequestKind::Chat => {
                    // Tool results must precede the message's own content.
                    self.tool_use
                        .attach_tool_results(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                    if self.tool_use.message_has_tool_results(message.id) {
                        continue;
                    }
                }
            }

            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            };

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1114
1115 fn attached_tracked_files_state(
1116 &self,
1117 messages: &mut Vec<LanguageModelRequestMessage>,
1118 cx: &App,
1119 ) {
1120 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1121
1122 let mut stale_message = String::new();
1123
1124 let action_log = self.action_log.read(cx);
1125
1126 for stale_file in action_log.stale_buffers(cx) {
1127 let Some(file) = stale_file.read(cx).file() else {
1128 continue;
1129 };
1130
1131 if stale_message.is_empty() {
1132 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1133 }
1134
1135 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1136 }
1137
1138 let mut content = Vec::with_capacity(2);
1139
1140 if !stale_message.is_empty() {
1141 content.push(stale_message.into());
1142 }
1143
1144 if !content.is_empty() {
1145 let context_message = LanguageModelRequestMessage {
1146 role: Role::User,
1147 content,
1148 cache: false,
1149 };
1150
1151 messages.push(context_message);
1152 }
1153 }
1154
    /// Spawns a background task that streams a completion for `request` from
    /// `model`, applying each event to the thread as it arrives.
    ///
    /// The task is registered in `self.pending_completions` and removed once
    /// the stream ends. On a `ToolUse` stop reason it starts the pending
    /// tools; on errors it emits a `ThreadEvent::ShowError` and cancels the
    /// completion.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Capture the request (and later the response events) only when a
        // request callback is registered.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot cumulative usage before streaming so the delta for this
            // completion can be reported to telemetry afterwards.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates appear to be cumulative per
                                // request, so fold in only the delta since the
                                // previous update.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking {
                                                text: chunk.to_string(),
                                                signature,
                                            }],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent
                                // Assistant message, inserting one if none
                                // exists yet.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Yield between events so other tasks get a chance to run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // Generate a title once the thread has its first exchange.
                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to specific UI errors;
                            // anything else becomes a generic error message
                            // built from the whole error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report this completion's token delta to telemetry.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1396
    /// Generates a short title for the thread using the configured summary
    /// model, storing it in `self.summary` and emitting
    /// `ThreadEvent::SummaryGenerated`.
    ///
    /// Silently does nothing when no summary model is configured or its
    /// provider is not authenticated.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
                Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
                If the conversation is about a specific subject, include it in the title. \
                Be descriptive. DO NOT speak in the first person."
                    .into(),
            ],
            cache: false,
        });

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    // Accumulate only the first line of each streamed chunk.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // Keep any existing summary if the model returned nothing.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1457
    /// Generates a detailed Markdown summary of the conversation, storing it
    /// in `self.detailed_summary_state`.
    ///
    /// Returns `None` when the summary is already up to date with the last
    /// message, there are no messages, no summary model is configured, or the
    /// provider is not authenticated; otherwise returns the generation task.
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a detailed summary of this conversation. Include:\n\
                 1. A brief overview of what was discussed\n\
                 2. Key facts or information discovered\n\
                 3. Outcomes or conclusions reached\n\
                 4. Any action items or next steps if any\n\
                 Format it in Markdown with headings and bullet points."
                    .into(),
            ],
            cache: false,
        });

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request itself fails, reset to NotGenerated so a later
            // call can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        // Mark as generating before handing the task back to the caller.
        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1531
1532 pub fn is_generating_detailed_summary(&self) -> bool {
1533 matches!(
1534 self.detailed_summary_state,
1535 DetailedSummaryState::Generating { .. }
1536 )
1537 }
1538
1539 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1540 self.auto_capture_telemetry(cx);
1541 let request = self.to_completion_request(RequestKind::Chat, cx);
1542 let messages = Arc::new(request.messages);
1543 let pending_tool_uses = self
1544 .tool_use
1545 .pending_tool_uses()
1546 .into_iter()
1547 .filter(|tool_use| tool_use.status.is_idle())
1548 .cloned()
1549 .collect::<Vec<_>>();
1550
1551 for tool_use in pending_tool_uses.iter() {
1552 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1553 if tool.needs_confirmation(&tool_use.input, cx)
1554 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1555 {
1556 self.tool_use.confirm_tool_use(
1557 tool_use.id.clone(),
1558 tool_use.ui_text.clone(),
1559 tool_use.input.clone(),
1560 messages.clone(),
1561 tool,
1562 );
1563 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1564 } else {
1565 self.run_tool(
1566 tool_use.id.clone(),
1567 tool_use.ui_text.clone(),
1568 tool_use.input.clone(),
1569 &messages,
1570 tool,
1571 cx,
1572 );
1573 }
1574 }
1575 }
1576
1577 pending_tool_uses
1578 }
1579
1580 pub fn run_tool(
1581 &mut self,
1582 tool_use_id: LanguageModelToolUseId,
1583 ui_text: impl Into<SharedString>,
1584 input: serde_json::Value,
1585 messages: &[LanguageModelRequestMessage],
1586 tool: Arc<dyn Tool>,
1587 cx: &mut Context<Thread>,
1588 ) {
1589 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1590 self.tool_use
1591 .run_pending_tool(tool_use_id, ui_text.into(), task);
1592 }
1593
    /// Runs `tool` (unless it is disabled in the working set) and returns a
    /// task that records the tool's output and calls `tool_finished` when it
    /// completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Disabled tools resolve immediately to an error result.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
                    })
                    .ok();
            }
        })
    }
1640
1641 fn tool_finished(
1642 &mut self,
1643 tool_use_id: LanguageModelToolUseId,
1644 pending_tool_use: Option<PendingToolUse>,
1645 canceled: bool,
1646 cx: &mut Context<Self>,
1647 ) {
1648 if self.all_tools_finished() {
1649 let model_registry = LanguageModelRegistry::read_global(cx);
1650 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1651 self.attach_tool_results(cx);
1652 if !canceled {
1653 self.send_to_model(model, RequestKind::Chat, cx);
1654 }
1655 }
1656 }
1657
1658 cx.emit(ThreadEvent::ToolFinished {
1659 tool_use_id,
1660 pending_tool_use,
1661 });
1662 }
1663
1664 /// Insert an empty message to be populated with tool results upon send.
1665 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1666 // TODO: Don't insert a dummy user message here. Ensure this works with the thinking model.
1667 // Insert a user message to contain the tool results.
1668 self.insert_user_message("Here are the tool results.", Vec::new(), None, cx);
1669 }
1670
1671 /// Cancels the last pending completion, if there are any pending.
1672 ///
1673 /// Returns whether a completion was canceled.
1674 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1675 let canceled = if self.pending_completions.pop().is_some() {
1676 true
1677 } else {
1678 let mut canceled = false;
1679 for pending_tool_use in self.tool_use.cancel_pending() {
1680 canceled = true;
1681 self.tool_finished(
1682 pending_tool_use.id.clone(),
1683 Some(pending_tool_use),
1684 true,
1685 cx,
1686 );
1687 }
1688 canceled
1689 };
1690 self.finalize_pending_checkpoint(cx);
1691 canceled
1692 }
1693
    /// Returns the thread-level feedback rating, if the user has provided one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1697
    /// Returns the feedback recorded for `message_id`, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1701
    /// Records `feedback` for `message_id` and reports it — together with a
    /// serialized copy of the thread and a project snapshot — to telemetry.
    ///
    /// No-ops when the same feedback is already recorded for the message.
    ///
    /// NOTE: `telemetry::event!` uses local identifiers as field names, so the
    /// locals below are named for their telemetry keys.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        // Record the feedback before the async reporting work starts.
        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events();

            Ok(())
        })
    }
1759
    /// Records thread-level `feedback`, attaching it to the last assistant
    /// message when one exists; otherwise records it on the thread itself and
    /// reports it to telemetry directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: fall back to thread-level feedback.
            // (`telemetry::event!` uses local identifiers as field names.)
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events();

                Ok(())
            })
        }
    }
1805
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; dirty-buffer paths are
    /// collected afterwards on the main context.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect the paths of all dirty (unsaved) file-backed buffers.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
1843
    /// Captures a [`WorktreeSnapshot`] for `worktree`: its absolute path plus
    /// git state (remote URL, HEAD SHA, current branch, and diff) when the
    /// worktree belongs to a local repository.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository that contains this worktree, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a git backend;
                            // otherwise report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha();
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested update/job results; any failure along the way
            // yields no git state rather than an error.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
1921
1922 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1923 let mut markdown = Vec::new();
1924
1925 if let Some(summary) = self.summary() {
1926 writeln!(markdown, "# {summary}\n")?;
1927 };
1928
1929 for message in self.messages() {
1930 writeln!(
1931 markdown,
1932 "## {role}\n",
1933 role = match message.role {
1934 Role::User => "User",
1935 Role::Assistant => "Assistant",
1936 Role::System => "System",
1937 }
1938 )?;
1939
1940 if !message.context.is_empty() {
1941 writeln!(markdown, "{}", message.context)?;
1942 }
1943
1944 for segment in &message.segments {
1945 match segment {
1946 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1947 MessageSegment::Thinking { text, .. } => {
1948 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
1949 }
1950 MessageSegment::RedactedThinking(_) => {}
1951 }
1952 }
1953
1954 for tool_use in self.tool_uses_for_message(message.id, cx) {
1955 writeln!(
1956 markdown,
1957 "**Use Tool: {} ({})**",
1958 tool_use.name, tool_use.id
1959 )?;
1960 writeln!(markdown, "```json")?;
1961 writeln!(
1962 markdown,
1963 "{}",
1964 serde_json::to_string_pretty(&tool_use.input)?
1965 )?;
1966 writeln!(markdown, "```")?;
1967 }
1968
1969 for tool_result in self.tool_results_for_message(message.id) {
1970 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1971 if tool_result.is_error {
1972 write!(markdown, " (Error)")?;
1973 }
1974
1975 writeln!(markdown, "**\n")?;
1976 writeln!(markdown, "{}", tool_result.content)?;
1977 }
1978 }
1979
1980 Ok(String::from_utf8_lossy(&markdown).to_string())
1981 }
1982
1983 pub fn keep_edits_in_range(
1984 &mut self,
1985 buffer: Entity<language::Buffer>,
1986 buffer_range: Range<language::Anchor>,
1987 cx: &mut Context<Self>,
1988 ) {
1989 self.action_log.update(cx, |action_log, cx| {
1990 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1991 });
1992 }
1993
1994 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1995 self.action_log
1996 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1997 }
1998
1999 pub fn reject_edits_in_ranges(
2000 &mut self,
2001 buffer: Entity<language::Buffer>,
2002 buffer_ranges: Vec<Range<language::Anchor>>,
2003 cx: &mut Context<Self>,
2004 ) -> Task<Result<()>> {
2005 self.action_log.update(cx, |action_log, cx| {
2006 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2007 })
2008 }
2009
    /// Returns the log of edits performed by tools in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2017
2018 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2019 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2020 return;
2021 }
2022
2023 let now = Instant::now();
2024 if let Some(last) = self.last_auto_capture_at {
2025 if now.duration_since(last).as_secs() < 10 {
2026 return;
2027 }
2028 }
2029
2030 self.last_auto_capture_at = Some(now);
2031
2032 let thread_id = self.id().clone();
2033 let github_login = self
2034 .project
2035 .read(cx)
2036 .user_store()
2037 .read(cx)
2038 .current_user()
2039 .map(|user| user.github_login.clone());
2040 let client = self.project.read(cx).client().clone();
2041 let serialize_task = self.serialize(cx);
2042
2043 cx.background_executor()
2044 .spawn(async move {
2045 if let Ok(serialized_thread) = serialize_task.await {
2046 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2047 telemetry::event!(
2048 "Agent Thread Auto-Captured",
2049 thread_id = thread_id.to_string(),
2050 thread_data = thread_data,
2051 auto_capture_reason = "tracked_user",
2052 github_login = github_login
2053 );
2054
2055 client.telemetry().flush_events();
2056 }
2057 }
2058 })
2059 .detach();
2060 }
2061
    /// Returns the total token usage accumulated across all completions in
    /// this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2065
2066 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2067 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2068 return TotalTokenUsage::default();
2069 };
2070
2071 let max = model.model.max_token_count();
2072
2073 let index = self
2074 .messages
2075 .iter()
2076 .position(|msg| msg.id == message_id)
2077 .unwrap_or(0);
2078
2079 if index == 0 {
2080 return TotalTokenUsage { total: 0, max };
2081 }
2082
2083 let token_usage = &self
2084 .request_token_usage
2085 .get(index - 1)
2086 .cloned()
2087 .unwrap_or_default();
2088
2089 TotalTokenUsage {
2090 total: token_usage.total_tokens() as usize,
2091 max,
2092 }
2093 }
2094
2095 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2096 let model_registry = LanguageModelRegistry::read_global(cx);
2097 let Some(model) = model_registry.default_model() else {
2098 return TotalTokenUsage::default();
2099 };
2100
2101 let max = model.model.max_token_count();
2102
2103 if let Some(exceeded_error) = &self.exceeded_window_error {
2104 if model.model.id() == exceeded_error.model_id {
2105 return TotalTokenUsage {
2106 total: exceeded_error.token_count,
2107 max,
2108 };
2109 }
2110 }
2111
2112 let total = self
2113 .token_usage_at_last_message()
2114 .unwrap_or_default()
2115 .total_tokens() as usize;
2116
2117 TotalTokenUsage { total, max }
2118 }
2119
    /// Returns the token usage recorded at the last message, falling back to
    /// the most recent recorded entry when the usage list is shorter than the
    /// message list.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2126
    /// Records `token_usage` as the usage at the last message.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Backfill any missing entries with the latest known usage so indices
        // stay aligned with `self.messages`.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2136
2137 pub fn deny_tool_use(
2138 &mut self,
2139 tool_use_id: LanguageModelToolUseId,
2140 tool_name: Arc<str>,
2141 cx: &mut Context<Self>,
2142 ) {
2143 let err = Err(anyhow::anyhow!(
2144 "Permission to run tool action denied by user"
2145 ));
2146
2147 self.tool_use
2148 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2149 self.tool_finished(tool_use_id.clone(), None, true, cx);
2150 }
2151}
2152
/// User-visible error conditions produced while interacting with a language
/// model.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request limit for the user's plan was reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2167
/// Events emitted by a `Thread` as it streams completions and runs tools.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// The provider reported updated request usage.
    UsageUpdated(RequestUsage),
    /// A streamed completion event was applied to the thread.
    StreamedCompletion,
    /// A chunk of assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The completion stream finished, successfully or not.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A thread title was generated by the summary model.
    SummaryGenerated,
    SummaryChanged,
    /// The completion stopped to use tools; these are the pending uses.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A single tool call finished (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires explicit user confirmation before it can run.
    ToolConfirmationNeeded,
}

impl EventEmitter<ThreadEvent> for Thread {}

/// A spawned completion task, keyed by `id` so it can be removed from
/// `Thread::pending_completions` when it finishes or is canceled.
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}
2200
2201#[cfg(test)]
2202mod tests {
2203 use super::*;
2204 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2205 use assistant_settings::AssistantSettings;
2206 use context_server::ContextServerSettings;
2207 use editor::EditorSettings;
2208 use gpui::TestAppContext;
2209 use project::{FakeFs, Project};
2210 use prompt_store::PromptBuilder;
2211 use serde_json::json;
2212 use settings::{Settings, SettingsStore};
2213 use std::sync::Arc;
2214 use theme::ThemeSettings;
2215 use util::path;
2216 use workspace::Workspace;
2217
    /// Verifies that context attached to a user message is recorded on the
    /// message and prepended to its content in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Context is prepended to the message text in the request.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2285
2286 #[gpui::test]
2287 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2288 init_test_settings(cx);
2289
2290 let project = create_test_project(
2291 cx,
2292 json!({
2293 "file1.rs": "fn function1() {}\n",
2294 "file2.rs": "fn function2() {}\n",
2295 "file3.rs": "fn function3() {}\n",
2296 }),
2297 )
2298 .await;
2299
2300 let (_, _thread_store, thread, context_store) =
2301 setup_test_environment(cx, project.clone()).await;
2302
2303 // Open files individually
2304 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2305 .await
2306 .unwrap();
2307 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2308 .await
2309 .unwrap();
2310 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2311 .await
2312 .unwrap();
2313
2314 // Get the context objects
2315 let contexts = context_store.update(cx, |store, _| store.context().clone());
2316 assert_eq!(contexts.len(), 3);
2317
2318 // First message with context 1
2319 let message1_id = thread.update(cx, |thread, cx| {
2320 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2321 });
2322
2323 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2324 let message2_id = thread.update(cx, |thread, cx| {
2325 thread.insert_user_message(
2326 "Message 2",
2327 vec![contexts[0].clone(), contexts[1].clone()],
2328 None,
2329 cx,
2330 )
2331 });
2332
2333 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2334 let message3_id = thread.update(cx, |thread, cx| {
2335 thread.insert_user_message(
2336 "Message 3",
2337 vec![
2338 contexts[0].clone(),
2339 contexts[1].clone(),
2340 contexts[2].clone(),
2341 ],
2342 None,
2343 cx,
2344 )
2345 });
2346
2347 // Check what contexts are included in each message
2348 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2349 (
2350 thread.message(message1_id).unwrap().clone(),
2351 thread.message(message2_id).unwrap().clone(),
2352 thread.message(message3_id).unwrap().clone(),
2353 )
2354 });
2355
2356 // First message should include context 1
2357 assert!(message1.context.contains("file1.rs"));
2358
2359 // Second message should include only context 2 (not 1)
2360 assert!(!message2.context.contains("file1.rs"));
2361 assert!(message2.context.contains("file2.rs"));
2362
2363 // Third message should include only context 3 (not 1 or 2)
2364 assert!(!message3.context.contains("file1.rs"));
2365 assert!(!message3.context.contains("file2.rs"));
2366 assert!(message3.context.contains("file3.rs"));
2367
2368 // Check entire request to make sure all contexts are properly included
2369 let request = thread.update(cx, |thread, cx| {
2370 thread.to_completion_request(RequestKind::Chat, cx)
2371 });
2372
2373 // The request should contain all 3 messages
2374 assert_eq!(request.messages.len(), 4);
2375
2376 // Check that the contexts are properly formatted in each message
2377 assert!(request.messages[1].string_contents().contains("file1.rs"));
2378 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2379 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2380
2381 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2382 assert!(request.messages[2].string_contents().contains("file2.rs"));
2383 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2384
2385 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2386 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2387 assert!(request.messages[3].string_contents().contains("file3.rs"));
2388 }
2389
2390 #[gpui::test]
2391 async fn test_message_without_files(cx: &mut TestAppContext) {
2392 init_test_settings(cx);
2393
2394 let project = create_test_project(
2395 cx,
2396 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2397 )
2398 .await;
2399
2400 let (_, _thread_store, thread, _context_store) =
2401 setup_test_environment(cx, project.clone()).await;
2402
2403 // Insert user message without any context (empty context vector)
2404 let message_id = thread.update(cx, |thread, cx| {
2405 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2406 });
2407
2408 // Check content and context in message object
2409 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2410
2411 // Context should be empty when no files are included
2412 assert_eq!(message.role, Role::User);
2413 assert_eq!(message.segments.len(), 1);
2414 assert_eq!(
2415 message.segments[0],
2416 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2417 );
2418 assert_eq!(message.context, "");
2419
2420 // Check message in request
2421 let request = thread.update(cx, |thread, cx| {
2422 thread.to_completion_request(RequestKind::Chat, cx)
2423 });
2424
2425 assert_eq!(request.messages.len(), 2);
2426 assert_eq!(
2427 request.messages[1].string_contents(),
2428 "What is the best way to learn Rust?"
2429 );
2430
2431 // Add second message, also without context
2432 let message2_id = thread.update(cx, |thread, cx| {
2433 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2434 });
2435
2436 let message2 =
2437 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2438 assert_eq!(message2.context, "");
2439
2440 // Check that both messages appear in the request
2441 let request = thread.update(cx, |thread, cx| {
2442 thread.to_completion_request(RequestKind::Chat, cx)
2443 });
2444
2445 assert_eq!(request.messages.len(), 3);
2446 assert_eq!(
2447 request.messages[1].string_contents(),
2448 "What is the best way to learn Rust?"
2449 );
2450 assert_eq!(
2451 request.messages[2].string_contents(),
2452 "Are there any good books?"
2453 );
2454 }
2455
2456 #[gpui::test]
2457 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2458 init_test_settings(cx);
2459
2460 let project = create_test_project(
2461 cx,
2462 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2463 )
2464 .await;
2465
2466 let (_workspace, _thread_store, thread, context_store) =
2467 setup_test_environment(cx, project.clone()).await;
2468
2469 // Open buffer and add it to context
2470 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2471 .await
2472 .unwrap();
2473
2474 let context =
2475 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2476
2477 // Insert user message with the buffer as context
2478 thread.update(cx, |thread, cx| {
2479 thread.insert_user_message("Explain this code", vec![context], None, cx)
2480 });
2481
2482 // Create a request and check that it doesn't have a stale buffer warning yet
2483 let initial_request = thread.update(cx, |thread, cx| {
2484 thread.to_completion_request(RequestKind::Chat, cx)
2485 });
2486
2487 // Make sure we don't have a stale file warning yet
2488 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2489 msg.string_contents()
2490 .contains("These files changed since last read:")
2491 });
2492 assert!(
2493 !has_stale_warning,
2494 "Should not have stale buffer warning before buffer is modified"
2495 );
2496
2497 // Modify the buffer
2498 buffer.update(cx, |buffer, cx| {
2499 // Find a position at the end of line 1
2500 buffer.edit(
2501 [(1..1, "\n println!(\"Added a new line\");\n")],
2502 None,
2503 cx,
2504 );
2505 });
2506
2507 // Insert another user message without context
2508 thread.update(cx, |thread, cx| {
2509 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2510 });
2511
2512 // Create a new request and check for the stale buffer warning
2513 let new_request = thread.update(cx, |thread, cx| {
2514 thread.to_completion_request(RequestKind::Chat, cx)
2515 });
2516
2517 // We should have a stale file warning as the last message
2518 let last_message = new_request
2519 .messages
2520 .last()
2521 .expect("Request should have messages");
2522
2523 // The last message should be the stale buffer notification
2524 assert_eq!(last_message.role, Role::User);
2525
2526 // Check the exact content of the message
2527 let expected_content = "These files changed since last read:\n- code.rs\n";
2528 assert_eq!(
2529 last_message.string_contents(),
2530 expected_content,
2531 "Last message should be exactly the stale buffer notification"
2532 );
2533 }
2534
2535 fn init_test_settings(cx: &mut TestAppContext) {
2536 cx.update(|cx| {
2537 let settings_store = SettingsStore::test(cx);
2538 cx.set_global(settings_store);
2539 language::init(cx);
2540 Project::init_settings(cx);
2541 AssistantSettings::register(cx);
2542 prompt_store::init(cx);
2543 thread_store::init(cx);
2544 workspace::init_settings(cx);
2545 ThemeSettings::register(cx);
2546 ContextServerSettings::register(cx);
2547 EditorSettings::register(cx);
2548 });
2549 }
2550
2551 // Helper to create a test project with test files
2552 async fn create_test_project(
2553 cx: &mut TestAppContext,
2554 files: serde_json::Value,
2555 ) -> Entity<Project> {
2556 let fs = FakeFs::new(cx.executor());
2557 fs.insert_tree(path!("/test"), files).await;
2558 Project::test(fs, [path!("/test").as_ref()], cx).await
2559 }
2560
2561 async fn setup_test_environment(
2562 cx: &mut TestAppContext,
2563 project: Entity<Project>,
2564 ) -> (
2565 Entity<Workspace>,
2566 Entity<ThreadStore>,
2567 Entity<Thread>,
2568 Entity<ContextStore>,
2569 ) {
2570 let (workspace, cx) =
2571 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2572
2573 let thread_store = cx
2574 .update(|_, cx| {
2575 ThreadStore::load(
2576 project.clone(),
2577 cx.new(|_| ToolWorkingSet::default()),
2578 Arc::new(PromptBuilder::new(None).unwrap()),
2579 cx,
2580 )
2581 })
2582 .await
2583 .unwrap();
2584
2585 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2586 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2587
2588 (workspace, thread_store, thread, context_store)
2589 }
2590
2591 async fn add_file_to_context(
2592 project: &Entity<Project>,
2593 context_store: &Entity<ContextStore>,
2594 path: &str,
2595 cx: &mut TestAppContext,
2596 ) -> Result<Entity<language::Buffer>> {
2597 let buffer_path = project
2598 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2599 .unwrap();
2600
2601 let buffer = project
2602 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2603 .await
2604 .unwrap();
2605
2606 context_store
2607 .update(cx, |store, cx| {
2608 store.add_file_from_buffer(buffer.clone(), cx)
2609 })
2610 .await?;
2611
2612 Ok(buffer)
2613 }
2614}