1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
42
/// The kind of completion request being sent to the model.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A normal conversational turn, including tool uses and results.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
49
/// Globally unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Backed by a UUID string, like [`ThreadId`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
78
79impl PromptId {
80 pub fn new() -> Self {
81 Self(Uuid::new_v4().to_string().into())
82 }
83}
84
85impl std::fmt::Display for PromptId {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", self.0)
88 }
89}
90
/// Identifier of a message within a single [`Thread`], assigned sequentially.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
93
94impl MessageId {
95 fn post_inc(&mut self) -> Self {
96 Self(post_inc(&mut self.0))
97 }
98}
99
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: plain text and model "thinking" output.
    pub segments: Vec<MessageSegment>,
    /// Attached context (files, symbols, etc.) already formatted as text.
    pub context: String,
}
108
impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed
    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
    pub fn should_display_content(&self) -> bool {
        // Delegates the per-segment decision to `MessageSegment::should_display`.
        self.segments.iter().all(|segment| segment.should_display())
    }

    /// Appends thinking text, merging into a trailing `Thinking` segment when
    /// one exists; a newly supplied signature replaces the stored one.
    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        if let Some(MessageSegment::Thinking {
            text: segment,
            signature: current_signature,
        }) = self.segments.last_mut()
        {
            if let Some(signature) = signature {
                *current_signature = Some(signature);
            }
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Thinking {
                text: text.to_string(),
                signature,
            });
        }
    }

    /// Appends plain text, merging into a trailing `Text` segment when one exists.
    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

    /// Renders the message as plain text: attached context first, then each
    /// segment; thinking is wrapped in `<think>` tags and redacted thinking
    /// is omitted entirely.
    pub fn to_string(&self) -> String {
        let mut result = String::new();

        if !self.context.is_empty() {
            result.push_str(&self.context);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    result.push_str("<think>\n");
                    result.push_str(text);
                    result.push_str("\n</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        result
    }
}
164
/// A single piece of [`Message`] content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary text.
    Text(String),
    /// Model "thinking" output, optionally accompanied by a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted-thinking payload; not displayable as text.
    RedactedThinking(Vec<u8>),
}
174
175impl MessageSegment {
176 pub fn should_display(&self) -> bool {
177 // We add USING_TOOL_MARKER when making a request that includes tool uses
178 // without non-whitespace text around them, and this can cause the model
179 // to mimic the pattern, so we consider those segments not displayable.
180 match self {
181 Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
182 Self::Thinking { text, .. } => text.is_empty() || text.trim() == USING_TOOL_MARKER,
183 Self::RedactedThinking(_) => false,
184 }
185 }
186}
187
/// A snapshot of the project's worktrees and unsaved-buffer state, captured
/// when the thread is created and stored alongside the serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved edits at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was captured.
    pub timestamp: DateTime<Utc>,
}
194
/// Snapshot of a single worktree: its path plus optional git metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree has no associated git repository.
    pub git_state: Option<GitState>,
}
200
/// Git metadata captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// Remote URL, if the repository has one configured.
    pub remote_url: Option<String>,
    /// SHA of the current HEAD commit, if resolvable.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    /// Diff text, if one was captured.
    pub diff: Option<String>,
}
208
/// Associates a git checkpoint with the user message it was taken for.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
214
/// User feedback (thumbs up / thumbs down) on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
220
/// Tracks the most recent checkpoint-restore operation so the UI can show
/// in-progress or error state for the affected message.
pub enum LastRestoreCheckpoint {
    /// A restore is currently in flight.
    Pending {
        message_id: MessageId,
    },
    /// The most recent restore failed with `error`.
    Error {
        message_id: MessageId,
        error: String,
    },
}
230
231impl LastRestoreCheckpoint {
232 pub fn message_id(&self) -> MessageId {
233 match self {
234 LastRestoreCheckpoint::Pending { message_id } => *message_id,
235 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
236 }
237 }
238}
239
/// Lifecycle of the thread's detailed (long-form) summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// Generation is in progress for the recorded message ID.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary was generated; `message_id` records where it applies.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
252
/// Running token total for a thread versus the model's context-window limit.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// The model's maximum (context-window) token count.
    pub max: usize,
}
258
259impl TotalTokenUsage {
260 pub fn ratio(&self) -> TokenUsageRatio {
261 #[cfg(debug_assertions)]
262 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
263 .unwrap_or("0.8".to_string())
264 .parse()
265 .unwrap();
266 #[cfg(not(debug_assertions))]
267 let warning_threshold: f32 = 0.8;
268
269 if self.total >= self.max {
270 TokenUsageRatio::Exceeded
271 } else if self.total as f32 / self.max as f32 >= warning_threshold {
272 TokenUsageRatio::Warning
273 } else {
274 TokenUsageRatio::Normal
275 }
276 }
277
278 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
279 TotalTokenUsage {
280 total: self.total + tokens,
281 max: self.max,
282 }
283 }
284}
285
/// How close a thread's token usage is to the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    /// Usage crossed the warning threshold (0.8 of the maximum by default).
    Warning,
    /// Usage reached or exceeded the maximum.
    Exceeded,
}
293
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// `None` until a summary has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    /// All context ever attached to this thread, keyed by context ID.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Which context IDs were attached alongside each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per user message, available for restore.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken with the latest user message; finalized when
    /// generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured at thread creation (awaited on serialize).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Optional hook invoked with each request and its raw response events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
}
328
/// Recorded when a request fails because the model's context window was exceeded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
336
337impl Thread {
    /// Creates a new, empty thread for `project`.
    ///
    /// Kicks off an asynchronous capture of the initial project snapshot; the
    /// shared task is awaited later when the thread is serialized.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
382
    /// Reconstructs a thread from its serialized form.
    ///
    /// The next message ID continues after the highest restored message ID so
    /// newly inserted messages cannot collide with restored ones.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        // Rebuild tool-use bookkeeping from the serialized messages.
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
455
    /// Installs a callback invoked with each completion request and the raw
    /// response events it produced.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
463
    /// The unique ID of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Stamps the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if one hasn't been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// The shared project context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder summary used before a real one is generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The thread's summary, falling back to [`Self::DEFAULT_SUMMARY`].
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
497
498 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
499 let Some(current_summary) = &self.summary else {
500 // Don't allow setting summary until generated
501 return;
502 };
503
504 let mut new_summary = new_summary.into();
505
506 if new_summary.is_empty() {
507 new_summary = Self::DEFAULT_SUMMARY;
508 }
509
510 if current_summary != &new_summary {
511 self.summary = Some(new_summary);
512 cx.emit(ThreadEvent::SummaryChanged);
513 }
514 }
515
    /// The latest detailed summary when one exists, otherwise the full thread text.
    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.latest_detailed_summary()
            .unwrap_or_else(|| self.text().into())
    }

    /// The detailed summary text, only once generation has completed.
    fn latest_detailed_summary(&self) -> Option<SharedString> {
        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
            Some(text.clone())
        } else {
            None
        }
    }

    /// Looks up a message by its ID.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    /// Iterates over all messages in insertion order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is streaming or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
544
    /// The pending tool use with the given ID, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
566
    /// Restores the project to `checkpoint` and, on success, truncates the
    /// thread back to the checkpoint's message.
    ///
    /// The restore is marked pending while the git operation runs; on failure
    /// the error is recorded instead of truncating.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way, the pending checkpoint no longer applies.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
602
    /// Resolves the checkpoint taken for the current turn once generation has
    /// finished.
    ///
    /// If the repository is unchanged since the checkpoint, the checkpoint is
    /// deleted; otherwise it is attached to its message so the user can
    /// restore it later. No-op while generation is still in progress or when
    /// there is no pending checkpoint.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The comparison checkpoint is only needed transiently.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // If we can't take a new checkpoint, keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
653
    /// Records `checkpoint` against its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore, if one has happened.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
664
665 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
666 let Some(message_ix) = self
667 .messages
668 .iter()
669 .rposition(|message| message.id == message_id)
670 else {
671 return;
672 };
673 for deleted_message in self.messages.drain(message_ix..) {
674 self.context_by_message.remove(&deleted_message.id);
675 self.checkpoints_by_message.remove(&deleted_message.id);
676 }
677 cx.notify();
678 }
679
680 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
681 self.context_by_message
682 .get(&id)
683 .into_iter()
684 .flat_map(|context| {
685 context
686 .iter()
687 .filter_map(|context_id| self.context.get(&context_id))
688 })
689 }
690
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// All tool uses associated with the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// All tool results associated with the given message.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    /// The result of the tool use with the given ID, if one was recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
712
713 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
714 Some(&self.tool_use.tool_result(id)?.content)
715 }
716
    /// The rendered card for the given tool use, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Whether the given message has any tool results attached.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }

    /// Filter out contexts that have already been included in previous messages
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }

    /// Whether this context has not yet been attached to the thread.
    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }
736
    /// Inserts a new user message and attaches `context` to it.
    ///
    /// Context already present in the thread is skipped. Buffers referenced by
    /// the newly attached context are registered with the action log so stale
    /// reads can later be reported to the model. When `git_checkpoint` is
    /// provided it becomes the pending checkpoint for this message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::Excerpt(excerpt_context) => {
                            log.buffer_added_as_context(
                                excerpt_context.context_buffer.buffer.clone(),
                                cx,
                            );
                        }
                        // These context kinds carry no buffers to track.
                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
                    }
                }
            });
        }

        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
812
813 pub fn insert_message(
814 &mut self,
815 role: Role,
816 segments: Vec<MessageSegment>,
817 cx: &mut Context<Self>,
818 ) -> MessageId {
819 let id = self.next_message_id.post_inc();
820 self.messages.push(Message {
821 id,
822 role,
823 segments,
824 context: String::new(),
825 });
826 self.touch_updated_at();
827 cx.emit(ThreadEvent::MessageAdded(id));
828 id
829 }
830
831 pub fn edit_message(
832 &mut self,
833 id: MessageId,
834 new_role: Role,
835 new_segments: Vec<MessageSegment>,
836 cx: &mut Context<Self>,
837 ) -> bool {
838 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
839 return false;
840 };
841 message.role = new_role;
842 message.segments = new_segments;
843 self.touch_updated_at();
844 cx.emit(ThreadEvent::MessageEdited(id));
845 true
846 }
847
848 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
849 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
850 return false;
851 };
852 self.messages.remove(index);
853 self.context_by_message.remove(&id);
854 self.touch_updated_at();
855 cx.emit(ThreadEvent::MessageDeleted(id));
856 true
857 }
858
859 /// Returns the representation of this [`Thread`] in a textual form.
860 ///
861 /// This is the representation we use when attaching a thread as context to another thread.
862 pub fn text(&self) -> String {
863 let mut text = String::new();
864
865 for message in &self.messages {
866 text.push_str(match message.role {
867 language_model::Role::User => "User:",
868 language_model::Role::Assistant => "Assistant:",
869 language_model::Role::System => "System:",
870 });
871 text.push('\n');
872
873 for segment in &message.segments {
874 match segment {
875 MessageSegment::Text(content) => text.push_str(content),
876 MessageSegment::Thinking { text: content, .. } => {
877 text.push_str(&format!("<think>{}</think>", content))
878 }
879 MessageSegment::RedactedThinking(_) => {}
880 }
881 }
882 text.push('\n');
883 }
884
885 text
886 }
887
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot before reading the thread state so
    /// the snapshot can be embedded in the serialized form.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
951
952 pub fn send_to_model(
953 &mut self,
954 model: Arc<dyn LanguageModel>,
955 request_kind: RequestKind,
956 cx: &mut Context<Self>,
957 ) {
958 let mut request = self.to_completion_request(request_kind, cx);
959 if model.supports_tools() {
960 request.tools = {
961 let mut tools = Vec::new();
962 tools.extend(
963 self.tools()
964 .read(cx)
965 .enabled_tools(cx)
966 .into_iter()
967 .filter_map(|tool| {
968 // Skip tools that cannot be supported
969 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
970 Some(LanguageModelRequestTool {
971 name: tool.name(),
972 description: tool.description(),
973 input_schema,
974 })
975 }),
976 );
977
978 tools
979 };
980 }
981
982 self.stream_completion(request, model, cx);
983 }
984
985 pub fn used_tools_since_last_user_message(&self) -> bool {
986 for message in self.messages.iter().rev() {
987 if self.tool_use.message_has_tool_results(message.id) {
988 return true;
989 } else if message.role == Role::User {
990 return false;
991 }
992 }
993
994 false
995 }
996
    /// Builds the [`LanguageModelRequest`] representing this thread.
    ///
    /// Prepends the system prompt (when the project context is ready),
    /// converts each message — attaching tool results/uses for chat requests
    /// and skipping tool-result messages when summarizing — marks the final
    /// message as cacheable, and appends a note about context files that
    /// changed since the model last read them.
    pub fn to_completion_request(
        &self,
        request_kind: RequestKind,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_results(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                    if self.tool_use.message_has_tool_results(message.id) {
                        continue;
                    }
                }
            }

            // Attached context is sent ahead of the message's own segments.
            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            };

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1114
1115 fn attached_tracked_files_state(
1116 &self,
1117 messages: &mut Vec<LanguageModelRequestMessage>,
1118 cx: &App,
1119 ) {
1120 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1121
1122 let mut stale_message = String::new();
1123
1124 let action_log = self.action_log.read(cx);
1125
1126 for stale_file in action_log.stale_buffers(cx) {
1127 let Some(file) = stale_file.read(cx).file() else {
1128 continue;
1129 };
1130
1131 if stale_message.is_empty() {
1132 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1133 }
1134
1135 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1136 }
1137
1138 let mut content = Vec::with_capacity(2);
1139
1140 if !stale_message.is_empty() {
1141 content.push(stale_message.into());
1142 }
1143
1144 if !content.is_empty() {
1145 let context_message = LanguageModelRequestMessage {
1146 role: Role::User,
1147 content,
1148 cache: false,
1149 };
1150
1151 messages.push(context_message);
1152 }
1153 }
1154
    /// Streams a completion for `request` from `model`, applying events to
    /// this thread as they arrive (streamed text, thinking, tool uses, usage
    /// updates), and handles the terminal result: kicking off pending tools,
    /// surfacing errors, and emitting telemetry.
    ///
    /// The spawned work is registered in `pending_completions` under a fresh
    /// id so it can be canceled later.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a debug/request callback is installed, keep a copy of the
        // request and record every response event so both can be handed to
        // the callback once the stream finishes.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot usage before streaming so the telemetry below can
            // report just this completion's delta.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            // Usage events appear to be cumulative within a
                            // request, so only the delta from the previous
                            // event is folded into the thread-wide total.
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking {
                                                text: chunk.to_string(),
                                                signature,
                                            }],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the latest assistant
                                // message; create one if none exists yet.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Yield so the UI can process the event before the next one.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // Auto-generate a thread title once there is a first exchange.
                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            // The model stopped to call tools: run them and notify.
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        // Map known error types to specific user-facing events;
                        // everything else becomes a generic error message.
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Remember the overflow so token-usage
                                        // reporting can pin to this count.
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1396
1397 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1398 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1399 return;
1400 };
1401
1402 if !model.provider.is_authenticated(cx) {
1403 return;
1404 }
1405
1406 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1407 request.messages.push(LanguageModelRequestMessage {
1408 role: Role::User,
1409 content: vec![
1410 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1411 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1412 If the conversation is about a specific subject, include it in the title. \
1413 Be descriptive. DO NOT speak in the first person."
1414 .into(),
1415 ],
1416 cache: false,
1417 });
1418
1419 self.pending_summary = cx.spawn(async move |this, cx| {
1420 async move {
1421 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1422 let (mut messages, usage) = stream.await?;
1423
1424 if let Some(usage) = usage {
1425 this.update(cx, |_thread, cx| {
1426 cx.emit(ThreadEvent::UsageUpdated(usage));
1427 })
1428 .ok();
1429 }
1430
1431 let mut new_summary = String::new();
1432 while let Some(message) = messages.stream.next().await {
1433 let text = message?;
1434 let mut lines = text.lines();
1435 new_summary.extend(lines.next());
1436
1437 // Stop if the LLM generated multiple lines.
1438 if lines.next().is_some() {
1439 break;
1440 }
1441 }
1442
1443 this.update(cx, |this, cx| {
1444 if !new_summary.is_empty() {
1445 this.summary = Some(new_summary.into());
1446 }
1447
1448 cx.emit(ThreadEvent::SummaryGenerated);
1449 })?;
1450
1451 anyhow::Ok(())
1452 }
1453 .log_err()
1454 .await
1455 });
1456 }
1457
    /// Starts generating a detailed Markdown summary of the conversation.
    ///
    /// Returns `None` when a summary for the latest message is already
    /// generated (or being generated), when no summary model is configured,
    /// or when the provider is not authenticated; otherwise returns the task
    /// driving the generation. On success the state moves to
    /// `DetailedSummaryState::Generated`; on request failure it is rolled
    /// back to `NotGenerated` so a later call can retry.
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a detailed summary of this conversation. Include:\n\
                 1. A brief overview of what was discussed\n\
                 2. Key facts or information discovered\n\
                 3. Outcomes or conclusions reached\n\
                 4. Any action items or next steps if any\n\
                 Format it in Markdown with headings and bullet points."
                    .into(),
            ],
            cache: false,
        });

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, reset the state so a future call
            // can retry generation.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            // Errors on individual chunks are logged and skipped rather than
            // aborting the whole summary.
            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        // Mark generation as in-flight before handing the task to the caller.
        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1531
    /// Whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1538
1539 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1540 self.auto_capture_telemetry(cx);
1541 let request = self.to_completion_request(RequestKind::Chat, cx);
1542 let messages = Arc::new(request.messages);
1543 let pending_tool_uses = self
1544 .tool_use
1545 .pending_tool_uses()
1546 .into_iter()
1547 .filter(|tool_use| tool_use.status.is_idle())
1548 .cloned()
1549 .collect::<Vec<_>>();
1550
1551 for tool_use in pending_tool_uses.iter() {
1552 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1553 if tool.needs_confirmation(&tool_use.input, cx)
1554 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1555 {
1556 self.tool_use.confirm_tool_use(
1557 tool_use.id.clone(),
1558 tool_use.ui_text.clone(),
1559 tool_use.input.clone(),
1560 messages.clone(),
1561 tool,
1562 );
1563 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1564 } else {
1565 self.run_tool(
1566 tool_use.id.clone(),
1567 tool_use.ui_text.clone(),
1568 tool_use.input.clone(),
1569 &messages,
1570 tool,
1571 cx,
1572 );
1573 }
1574 }
1575 }
1576
1577 pending_tool_uses
1578 }
1579
1580 pub fn run_tool(
1581 &mut self,
1582 tool_use_id: LanguageModelToolUseId,
1583 ui_text: impl Into<SharedString>,
1584 input: serde_json::Value,
1585 messages: &[LanguageModelRequestMessage],
1586 tool: Arc<dyn Tool>,
1587 cx: &mut Context<Thread>,
1588 ) {
1589 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1590 self.tool_use
1591 .run_pending_tool(tool_use_id, ui_text.into(), task);
1592 }
1593
1594 fn spawn_tool_use(
1595 &mut self,
1596 tool_use_id: LanguageModelToolUseId,
1597 messages: &[LanguageModelRequestMessage],
1598 input: serde_json::Value,
1599 tool: Arc<dyn Tool>,
1600 cx: &mut Context<Thread>,
1601 ) -> Task<()> {
1602 let tool_name: Arc<str> = tool.name().into();
1603
1604 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1605 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1606 } else {
1607 tool.run(
1608 input,
1609 messages,
1610 self.project.clone(),
1611 self.action_log.clone(),
1612 cx,
1613 )
1614 };
1615
1616 // Store the card separately if it exists
1617 if let Some(card) = tool_result.card.clone() {
1618 self.tool_use
1619 .insert_tool_result_card(tool_use_id.clone(), card);
1620 }
1621
1622 cx.spawn({
1623 async move |thread: WeakEntity<Thread>, cx| {
1624 let output = tool_result.output.await;
1625
1626 thread
1627 .update(cx, |thread, cx| {
1628 let pending_tool_use = thread.tool_use.insert_tool_output(
1629 tool_use_id.clone(),
1630 tool_name,
1631 output,
1632 cx,
1633 );
1634 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1635 })
1636 .ok();
1637 }
1638 })
1639 }
1640
1641 fn tool_finished(
1642 &mut self,
1643 tool_use_id: LanguageModelToolUseId,
1644 pending_tool_use: Option<PendingToolUse>,
1645 canceled: bool,
1646 cx: &mut Context<Self>,
1647 ) {
1648 if self.all_tools_finished() {
1649 let model_registry = LanguageModelRegistry::read_global(cx);
1650 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1651 self.attach_tool_results(cx);
1652 if !canceled {
1653 self.send_to_model(model, RequestKind::Chat, cx);
1654 }
1655 }
1656 }
1657
1658 cx.emit(ThreadEvent::ToolFinished {
1659 tool_use_id,
1660 pending_tool_use,
1661 });
1662 }
1663
    /// Insert an empty message to be populated with tool results upon send.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message id, so they will populate
        // this empty message before sending to model. Would prefer this to be more straightforward.
        self.insert_message(Role::User, vec![], cx);
        // Feedback-capture opportunity: thread content just changed.
        self.auto_capture_telemetry(cx);
    }
1671
1672 /// Cancels the last pending completion, if there are any pending.
1673 ///
1674 /// Returns whether a completion was canceled.
1675 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1676 let canceled = if self.pending_completions.pop().is_some() {
1677 true
1678 } else {
1679 let mut canceled = false;
1680 for pending_tool_use in self.tool_use.cancel_pending() {
1681 canceled = true;
1682 self.tool_finished(
1683 pending_tool_use.id.clone(),
1684 Some(pending_tool_use),
1685 true,
1686 cx,
1687 );
1688 }
1689 canceled
1690 };
1691 self.finalize_pending_checkpoint(cx);
1692 canceled
1693 }
1694
    /// The thread-level rating the user gave, if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1698
    /// The rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1702
    /// Records the user's rating for a single message and reports it to
    /// telemetry together with the serialized thread, the rated message's
    /// content, the enabled tool names, and a project snapshot.
    ///
    /// Returns a ready `Ok` task without re-reporting when the same rating
    /// is already recorded for the message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Gather everything that must be captured on the main thread before
        // moving the reporting work to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure shouldn't block the rating itself.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events();

            Ok(())
        })
    }
1760
    /// Records the user's rating for the thread.
    ///
    /// When an assistant message exists, the rating is attached to the most
    /// recent one via `report_message_feedback`; otherwise a thread-level
    /// rating is stored and reported to telemetry directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach to: record thread-level feedback.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // A serialization failure shouldn't block the rating itself.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events();

                Ok(())
            })
        }
    }
1806
1807 /// Create a snapshot of the current project state including git information and unsaved buffers.
1808 fn project_snapshot(
1809 project: Entity<Project>,
1810 cx: &mut Context<Self>,
1811 ) -> Task<Arc<ProjectSnapshot>> {
1812 let git_store = project.read(cx).git_store().clone();
1813 let worktree_snapshots: Vec<_> = project
1814 .read(cx)
1815 .visible_worktrees(cx)
1816 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1817 .collect();
1818
1819 cx.spawn(async move |_, cx| {
1820 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1821
1822 let mut unsaved_buffers = Vec::new();
1823 cx.update(|app_cx| {
1824 let buffer_store = project.read(app_cx).buffer_store();
1825 for buffer_handle in buffer_store.read(app_cx).buffers() {
1826 let buffer = buffer_handle.read(app_cx);
1827 if buffer.is_dirty() {
1828 if let Some(file) = buffer.file() {
1829 let path = file.path().to_string_lossy().to_string();
1830 unsaved_buffers.push(path);
1831 }
1832 }
1833 }
1834 })
1835 .ok();
1836
1837 Arc::new(ProjectSnapshot {
1838 worktree_snapshots,
1839 unsaved_buffer_paths: unsaved_buffers,
1840 timestamp: Utc::now(),
1841 })
1842 })
1843 }
1844
    /// Captures one worktree's absolute path and, when it belongs to a git
    /// repository, its git state (branch, head SHA, remote URL, and a
    /// HEAD-to-worktree diff) for inclusion in a project snapshot.
    ///
    /// Failures at any layer degrade gracefully to `git_state: None` (or an
    /// empty path when the app context is gone) rather than erroring.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository that contains this worktree's root, if any,
            // and queue a job on it to read the git details.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can
                            // query; otherwise report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha();
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested fallible layers in turn: repo lookup, entity
            // update, and finally the job's own async result.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
1922
1923 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1924 let mut markdown = Vec::new();
1925
1926 if let Some(summary) = self.summary() {
1927 writeln!(markdown, "# {summary}\n")?;
1928 };
1929
1930 for message in self.messages() {
1931 writeln!(
1932 markdown,
1933 "## {role}\n",
1934 role = match message.role {
1935 Role::User => "User",
1936 Role::Assistant => "Assistant",
1937 Role::System => "System",
1938 }
1939 )?;
1940
1941 if !message.context.is_empty() {
1942 writeln!(markdown, "{}", message.context)?;
1943 }
1944
1945 for segment in &message.segments {
1946 match segment {
1947 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1948 MessageSegment::Thinking { text, .. } => {
1949 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
1950 }
1951 MessageSegment::RedactedThinking(_) => {}
1952 }
1953 }
1954
1955 for tool_use in self.tool_uses_for_message(message.id, cx) {
1956 writeln!(
1957 markdown,
1958 "**Use Tool: {} ({})**",
1959 tool_use.name, tool_use.id
1960 )?;
1961 writeln!(markdown, "```json")?;
1962 writeln!(
1963 markdown,
1964 "{}",
1965 serde_json::to_string_pretty(&tool_use.input)?
1966 )?;
1967 writeln!(markdown, "```")?;
1968 }
1969
1970 for tool_result in self.tool_results_for_message(message.id) {
1971 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1972 if tool_result.is_error {
1973 write!(markdown, " (Error)")?;
1974 }
1975
1976 writeln!(markdown, "**\n")?;
1977 writeln!(markdown, "{}", tool_result.content)?;
1978 }
1979 }
1980
1981 Ok(String::from_utf8_lossy(&markdown).to_string())
1982 }
1983
    /// Marks the agent's edits within `buffer_range` of `buffer` as accepted
    /// by the user, delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
1994
    /// Marks every outstanding agent edit as accepted by the user.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
1999
    /// Reverts the agent's edits within the given ranges of `buffer`.
    ///
    /// The returned task resolves once the rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2010
    /// The log of buffer reads and edits performed by tools in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2014
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2018
2019 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2020 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2021 return;
2022 }
2023
2024 let now = Instant::now();
2025 if let Some(last) = self.last_auto_capture_at {
2026 if now.duration_since(last).as_secs() < 10 {
2027 return;
2028 }
2029 }
2030
2031 self.last_auto_capture_at = Some(now);
2032
2033 let thread_id = self.id().clone();
2034 let github_login = self
2035 .project
2036 .read(cx)
2037 .user_store()
2038 .read(cx)
2039 .current_user()
2040 .map(|user| user.github_login.clone());
2041 let client = self.project.read(cx).client().clone();
2042 let serialize_task = self.serialize(cx);
2043
2044 cx.background_executor()
2045 .spawn(async move {
2046 if let Ok(serialized_thread) = serialize_task.await {
2047 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2048 telemetry::event!(
2049 "Agent Thread Auto-Captured",
2050 thread_id = thread_id.to_string(),
2051 thread_data = thread_data,
2052 auto_capture_reason = "tracked_user",
2053 github_login = github_login
2054 );
2055
2056 client.telemetry().flush_events();
2057 }
2058 }
2059 })
2060 .detach();
2061 }
2062
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2066
2067 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2068 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2069 return TotalTokenUsage::default();
2070 };
2071
2072 let max = model.model.max_token_count();
2073
2074 let index = self
2075 .messages
2076 .iter()
2077 .position(|msg| msg.id == message_id)
2078 .unwrap_or(0);
2079
2080 if index == 0 {
2081 return TotalTokenUsage { total: 0, max };
2082 }
2083
2084 let token_usage = &self
2085 .request_token_usage
2086 .get(index - 1)
2087 .cloned()
2088 .unwrap_or_default();
2089
2090 TotalTokenUsage {
2091 total: token_usage.total_tokens() as usize,
2092 max,
2093 }
2094 }
2095
2096 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2097 let model_registry = LanguageModelRegistry::read_global(cx);
2098 let Some(model) = model_registry.default_model() else {
2099 return TotalTokenUsage::default();
2100 };
2101
2102 let max = model.model.max_token_count();
2103
2104 if let Some(exceeded_error) = &self.exceeded_window_error {
2105 if model.model.id() == exceeded_error.model_id {
2106 return TotalTokenUsage {
2107 total: exceeded_error.token_count,
2108 max,
2109 };
2110 }
2111 }
2112
2113 let total = self
2114 .token_usage_at_last_message()
2115 .unwrap_or_default()
2116 .total_tokens() as usize;
2117
2118 TotalTokenUsage { total, max }
2119 }
2120
    /// Token usage recorded for the last message.
    ///
    /// Indexes `request_token_usage` by the final message's position; when
    /// no entry exists at that index yet, falls back to the most recently
    /// recorded usage.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2127
    /// Records `token_usage` for the last message.
    ///
    /// First resizes `request_token_usage` to match the message count —
    /// padding any gap with the last-known usage (and truncating if messages
    /// were removed) — then overwrites the final entry.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2137
2138 pub fn deny_tool_use(
2139 &mut self,
2140 tool_use_id: LanguageModelToolUseId,
2141 tool_name: Arc<str>,
2142 cx: &mut Context<Self>,
2143 ) {
2144 let err = Err(anyhow::anyhow!(
2145 "Permission to run tool action denied by user"
2146 ));
2147
2148 self.tool_use
2149 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2150 self.tool_finished(tool_use_id.clone(), None, true, cx);
2151 }
2152}
2153
/// User-facing error conditions surfaced while interacting with a language
/// model, emitted via [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider refused the request pending payment.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's configured monthly spending cap has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request quota for the user's plan has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Catch-all error carrying a display header and detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2168
/// Events emitted by a `Thread` as completions stream in, messages change,
/// and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// Updated request-usage numbers arrived from the provider.
    UsageUpdated(RequestUsage),
    /// A streamed completion event was applied to the thread.
    StreamedCompletion,
    /// A chunk of assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The completion finished, successfully (with a stop reason) or not.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A short thread title was generated by the summary model.
    SummaryGenerated,
    SummaryChanged,
    /// The model requested tool use; these tools are being started.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (or was canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires explicit user confirmation before running.
    ToolConfirmationNeeded,
}

impl EventEmitter<ThreadEvent> for Thread {}
2196
/// An in-flight completion tracked by the thread.
///
/// `cancel_last_completion` pops these from `pending_completions`; dropping
/// the entry (and its task) is presumably what stops the stream — confirm
/// against gpui `Task` drop semantics.
struct PendingCompletion {
    id: usize,
    // Held only to keep the spawned completion task alive.
    _task: Task<()>,
}
2201
2202#[cfg(test)]
2203mod tests {
2204 use super::*;
2205 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2206 use assistant_settings::AssistantSettings;
2207 use context_server::ContextServerSettings;
2208 use editor::EditorSettings;
2209 use gpui::TestAppContext;
2210 use project::{FakeFs, Project};
2211 use prompt_store::PromptBuilder;
2212 use serde_json::json;
2213 use settings::{Settings, SettingsStore};
2214 use std::sync::Arc;
2215 use theme::ThemeSettings;
2216 use util::path;
2217 use workspace::Workspace;
2218
2219 #[gpui::test]
2220 async fn test_message_with_context(cx: &mut TestAppContext) {
2221 init_test_settings(cx);
2222
2223 let project = create_test_project(
2224 cx,
2225 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2226 )
2227 .await;
2228
2229 let (_workspace, _thread_store, thread, context_store) =
2230 setup_test_environment(cx, project.clone()).await;
2231
2232 add_file_to_context(&project, &context_store, "test/code.rs", cx)
2233 .await
2234 .unwrap();
2235
2236 let context =
2237 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2238
2239 // Insert user message with context
2240 let message_id = thread.update(cx, |thread, cx| {
2241 thread.insert_user_message("Please explain this code", vec![context], None, cx)
2242 });
2243
2244 // Check content and context in message object
2245 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2246
2247 // Use different path format strings based on platform for the test
2248 #[cfg(windows)]
2249 let path_part = r"test\code.rs";
2250 #[cfg(not(windows))]
2251 let path_part = "test/code.rs";
2252
2253 let expected_context = format!(
2254 r#"
2255<context>
2256The following items were attached by the user. You don't need to use other tools to read them.
2257
2258<files>
2259```rs {path_part}
2260fn main() {{
2261 println!("Hello, world!");
2262}}
2263```
2264</files>
2265</context>
2266"#
2267 );
2268
2269 assert_eq!(message.role, Role::User);
2270 assert_eq!(message.segments.len(), 1);
2271 assert_eq!(
2272 message.segments[0],
2273 MessageSegment::Text("Please explain this code".to_string())
2274 );
2275 assert_eq!(message.context, expected_context);
2276
2277 // Check message in request
2278 let request = thread.update(cx, |thread, cx| {
2279 thread.to_completion_request(RequestKind::Chat, cx)
2280 });
2281
2282 assert_eq!(request.messages.len(), 2);
2283 let expected_full_message = format!("{}Please explain this code", expected_context);
2284 assert_eq!(request.messages[1].string_contents(), expected_full_message);
2285 }
2286
2287 #[gpui::test]
2288 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2289 init_test_settings(cx);
2290
2291 let project = create_test_project(
2292 cx,
2293 json!({
2294 "file1.rs": "fn function1() {}\n",
2295 "file2.rs": "fn function2() {}\n",
2296 "file3.rs": "fn function3() {}\n",
2297 }),
2298 )
2299 .await;
2300
2301 let (_, _thread_store, thread, context_store) =
2302 setup_test_environment(cx, project.clone()).await;
2303
2304 // Open files individually
2305 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2306 .await
2307 .unwrap();
2308 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2309 .await
2310 .unwrap();
2311 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2312 .await
2313 .unwrap();
2314
2315 // Get the context objects
2316 let contexts = context_store.update(cx, |store, _| store.context().clone());
2317 assert_eq!(contexts.len(), 3);
2318
2319 // First message with context 1
2320 let message1_id = thread.update(cx, |thread, cx| {
2321 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2322 });
2323
2324 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2325 let message2_id = thread.update(cx, |thread, cx| {
2326 thread.insert_user_message(
2327 "Message 2",
2328 vec![contexts[0].clone(), contexts[1].clone()],
2329 None,
2330 cx,
2331 )
2332 });
2333
2334 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2335 let message3_id = thread.update(cx, |thread, cx| {
2336 thread.insert_user_message(
2337 "Message 3",
2338 vec![
2339 contexts[0].clone(),
2340 contexts[1].clone(),
2341 contexts[2].clone(),
2342 ],
2343 None,
2344 cx,
2345 )
2346 });
2347
2348 // Check what contexts are included in each message
2349 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2350 (
2351 thread.message(message1_id).unwrap().clone(),
2352 thread.message(message2_id).unwrap().clone(),
2353 thread.message(message3_id).unwrap().clone(),
2354 )
2355 });
2356
2357 // First message should include context 1
2358 assert!(message1.context.contains("file1.rs"));
2359
2360 // Second message should include only context 2 (not 1)
2361 assert!(!message2.context.contains("file1.rs"));
2362 assert!(message2.context.contains("file2.rs"));
2363
2364 // Third message should include only context 3 (not 1 or 2)
2365 assert!(!message3.context.contains("file1.rs"));
2366 assert!(!message3.context.contains("file2.rs"));
2367 assert!(message3.context.contains("file3.rs"));
2368
2369 // Check entire request to make sure all contexts are properly included
2370 let request = thread.update(cx, |thread, cx| {
2371 thread.to_completion_request(RequestKind::Chat, cx)
2372 });
2373
2374 // The request should contain all 3 messages
2375 assert_eq!(request.messages.len(), 4);
2376
2377 // Check that the contexts are properly formatted in each message
2378 assert!(request.messages[1].string_contents().contains("file1.rs"));
2379 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2380 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2381
2382 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2383 assert!(request.messages[2].string_contents().contains("file2.rs"));
2384 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2385
2386 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2387 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2388 assert!(request.messages[3].string_contents().contains("file3.rs"));
2389 }
2390
2391 #[gpui::test]
2392 async fn test_message_without_files(cx: &mut TestAppContext) {
2393 init_test_settings(cx);
2394
2395 let project = create_test_project(
2396 cx,
2397 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2398 )
2399 .await;
2400
2401 let (_, _thread_store, thread, _context_store) =
2402 setup_test_environment(cx, project.clone()).await;
2403
2404 // Insert user message without any context (empty context vector)
2405 let message_id = thread.update(cx, |thread, cx| {
2406 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2407 });
2408
2409 // Check content and context in message object
2410 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2411
2412 // Context should be empty when no files are included
2413 assert_eq!(message.role, Role::User);
2414 assert_eq!(message.segments.len(), 1);
2415 assert_eq!(
2416 message.segments[0],
2417 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2418 );
2419 assert_eq!(message.context, "");
2420
2421 // Check message in request
2422 let request = thread.update(cx, |thread, cx| {
2423 thread.to_completion_request(RequestKind::Chat, cx)
2424 });
2425
2426 assert_eq!(request.messages.len(), 2);
2427 assert_eq!(
2428 request.messages[1].string_contents(),
2429 "What is the best way to learn Rust?"
2430 );
2431
2432 // Add second message, also without context
2433 let message2_id = thread.update(cx, |thread, cx| {
2434 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2435 });
2436
2437 let message2 =
2438 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2439 assert_eq!(message2.context, "");
2440
2441 // Check that both messages appear in the request
2442 let request = thread.update(cx, |thread, cx| {
2443 thread.to_completion_request(RequestKind::Chat, cx)
2444 });
2445
2446 assert_eq!(request.messages.len(), 3);
2447 assert_eq!(
2448 request.messages[1].string_contents(),
2449 "What is the best way to learn Rust?"
2450 );
2451 assert_eq!(
2452 request.messages[2].string_contents(),
2453 "Are there any good books?"
2454 );
2455 }
2456
2457 #[gpui::test]
2458 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2459 init_test_settings(cx);
2460
2461 let project = create_test_project(
2462 cx,
2463 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2464 )
2465 .await;
2466
2467 let (_workspace, _thread_store, thread, context_store) =
2468 setup_test_environment(cx, project.clone()).await;
2469
2470 // Open buffer and add it to context
2471 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2472 .await
2473 .unwrap();
2474
2475 let context =
2476 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2477
2478 // Insert user message with the buffer as context
2479 thread.update(cx, |thread, cx| {
2480 thread.insert_user_message("Explain this code", vec![context], None, cx)
2481 });
2482
2483 // Create a request and check that it doesn't have a stale buffer warning yet
2484 let initial_request = thread.update(cx, |thread, cx| {
2485 thread.to_completion_request(RequestKind::Chat, cx)
2486 });
2487
2488 // Make sure we don't have a stale file warning yet
2489 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2490 msg.string_contents()
2491 .contains("These files changed since last read:")
2492 });
2493 assert!(
2494 !has_stale_warning,
2495 "Should not have stale buffer warning before buffer is modified"
2496 );
2497
2498 // Modify the buffer
2499 buffer.update(cx, |buffer, cx| {
2500 // Find a position at the end of line 1
2501 buffer.edit(
2502 [(1..1, "\n println!(\"Added a new line\");\n")],
2503 None,
2504 cx,
2505 );
2506 });
2507
2508 // Insert another user message without context
2509 thread.update(cx, |thread, cx| {
2510 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2511 });
2512
2513 // Create a new request and check for the stale buffer warning
2514 let new_request = thread.update(cx, |thread, cx| {
2515 thread.to_completion_request(RequestKind::Chat, cx)
2516 });
2517
2518 // We should have a stale file warning as the last message
2519 let last_message = new_request
2520 .messages
2521 .last()
2522 .expect("Request should have messages");
2523
2524 // The last message should be the stale buffer notification
2525 assert_eq!(last_message.role, Role::User);
2526
2527 // Check the exact content of the message
2528 let expected_content = "These files changed since last read:\n- code.rs\n";
2529 assert_eq!(
2530 last_message.string_contents(),
2531 expected_content,
2532 "Last message should be exactly the stale buffer notification"
2533 );
2534 }
2535
2536 fn init_test_settings(cx: &mut TestAppContext) {
2537 cx.update(|cx| {
2538 let settings_store = SettingsStore::test(cx);
2539 cx.set_global(settings_store);
2540 language::init(cx);
2541 Project::init_settings(cx);
2542 AssistantSettings::register(cx);
2543 prompt_store::init(cx);
2544 thread_store::init(cx);
2545 workspace::init_settings(cx);
2546 ThemeSettings::register(cx);
2547 ContextServerSettings::register(cx);
2548 EditorSettings::register(cx);
2549 });
2550 }
2551
2552 // Helper to create a test project with test files
2553 async fn create_test_project(
2554 cx: &mut TestAppContext,
2555 files: serde_json::Value,
2556 ) -> Entity<Project> {
2557 let fs = FakeFs::new(cx.executor());
2558 fs.insert_tree(path!("/test"), files).await;
2559 Project::test(fs, [path!("/test").as_ref()], cx).await
2560 }
2561
    /// Builds the entities the tests above need on top of `project`: a workspace,
    /// a thread store, a fresh thread, and a context store.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        // NOTE: `cx` is shadowed here by the window-scoped context returned from
        // `add_window_view`; every statement below uses the shadowed context.
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // Load a thread store with a default tool working set and a prompt
        // builder constructed from no overrides.
        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        // A fresh, empty thread plus a context store for the same project
        // (second argument left `None` — see `ContextStore::new` for its meaning).
        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }
2591
2592 async fn add_file_to_context(
2593 project: &Entity<Project>,
2594 context_store: &Entity<ContextStore>,
2595 path: &str,
2596 cx: &mut TestAppContext,
2597 ) -> Result<Entity<language::Buffer>> {
2598 let buffer_path = project
2599 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2600 .unwrap();
2601
2602 let buffer = project
2603 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2604 .await
2605 .unwrap();
2606
2607 context_store
2608 .update(cx, |store, cx| {
2609 store.add_file_from_buffer(buffer.clone(), cx)
2610 })
2611 .await?;
2612
2613 Ok(buffer)
2614 }
2615}