1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
42
/// A stable, globally unique identifier for a [`Thread`].
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
/// A monotonically increasing, per-thread index identifying a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered content pieces (text / thinking / redacted thinking).
    pub segments: Vec<MessageSegment>,
    // Pre-formatted attached-context string prepended when rendering the message as text.
    pub context: String,
    // Images attached to the message; not persisted (see `Thread::deserialize`).
    pub images: Vec<LanguageModelImage>,
}
102
103impl Message {
104 /// Returns whether the message contains any meaningful text that should be displayed
105 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
106 pub fn should_display_content(&self) -> bool {
107 self.segments.iter().all(|segment| segment.should_display())
108 }
109
110 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
111 if let Some(MessageSegment::Thinking {
112 text: segment,
113 signature: current_signature,
114 }) = self.segments.last_mut()
115 {
116 if let Some(signature) = signature {
117 *current_signature = Some(signature);
118 }
119 segment.push_str(text);
120 } else {
121 self.segments.push(MessageSegment::Thinking {
122 text: text.to_string(),
123 signature,
124 });
125 }
126 }
127
128 pub fn push_text(&mut self, text: &str) {
129 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
130 segment.push_str(text);
131 } else {
132 self.segments.push(MessageSegment::Text(text.to_string()));
133 }
134 }
135
136 pub fn to_string(&self) -> String {
137 let mut result = String::new();
138
139 if !self.context.is_empty() {
140 result.push_str(&self.context);
141 }
142
143 for segment in &self.segments {
144 match segment {
145 MessageSegment::Text(text) => result.push_str(text),
146 MessageSegment::Thinking { text, .. } => {
147 result.push_str("<think>\n");
148 result.push_str(text);
149 result.push_str("\n</think>");
150 }
151 MessageSegment::RedactedThinking(_) => {}
152 }
153 }
154
155 result
156 }
157}
158
/// One piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking payload; carries no displayable text.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has user-visible content worth rendering.
    ///
    /// Fix: this previously returned `text.is_empty()`, i.e. it reported *empty*
    /// segments as displayable — inverted relative to the method's name and to
    /// [`Message::should_display_content`]'s documented intent. Call sites that
    /// compensated for the inverted value should be audited.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
178
/// A snapshot of the project's state captured when a thread starts,
/// attached to the serialized thread for telemetry/debugging.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
185
/// Per-worktree portion of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}
191
/// Git repository state captured for a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
199
/// Associates a git checkpoint with the message it was taken before,
/// so the project can be rolled back to that point.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
205
/// User thumbs-up / thumbs-down feedback on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
211
212pub enum LastRestoreCheckpoint {
213 Pending {
214 message_id: MessageId,
215 },
216 Error {
217 message_id: MessageId,
218 error: String,
219 },
220}
221
222impl LastRestoreCheckpoint {
223 pub fn message_id(&self) -> MessageId {
224 match self {
225 LastRestoreCheckpoint::Pending { message_id } => *message_id,
226 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
227 }
228 }
229}
230
/// State machine for the long-form thread summary: not yet generated,
/// generation in flight, or generated text. `message_id` records the last
/// message covered by the (pending or finished) summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
243
/// Running token usage for a thread measured against the model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// The model's context-window size. Note: with the `Default` value of 0,
    /// `ratio()` reports `Exceeded` (0 >= 0).
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, near the limit, or over the limit.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via an env var
        // for testing the warning UI.
        //
        // Fix: a malformed ZED_THREAD_WARNING_THRESHOLD previously panicked via
        // `.parse().unwrap()`; env vars are user input, so fall back to the
        // default instead.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy with `tokens` added to the total; `max` is unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Severity bucket computed by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
284
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    // Last-modified time; bumped by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    // Short title; `None` until a summary has been generated.
    summary: Option<SharedString>,
    // In-flight summary generation task.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    // All context ever attached to this thread, keyed by id.
    context: BTreeMap<ContextId, AssistantContext>,
    // Which context ids were attached alongside each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    // Counter used to tag pending completions with unique ids.
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken at the last user message, finalized when generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    // Rate-limits automatic telemetry capture.
    last_auto_capture_at: Option<Instant>,
    // Test/debug hook invoked with each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
}
319
/// Records that the last request overflowed the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
327
328impl Thread {
    /// Creates a new, empty thread with a fresh id and prompt id.
    ///
    /// Kicks off an asynchronous capture of the initial project snapshot,
    /// shared so later consumers (e.g. serialization) can await it.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Start snapshotting immediately; the shared task lets multiple
            // awaiters observe the same result.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
373
    /// Reconstructs a thread from its serialized form.
    ///
    /// Tool-use state is rebuilt from the serialized messages. Ephemeral state
    /// (pending completions, checkpoints, feedback, attached context indexes)
    /// starts out empty.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue message numbering after the highest serialized id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                    // NOTE(review): images are not round-tripped through
                    // serialization and are dropped on reload.
                    images: Vec::new(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
447
    /// Installs a debug/test hook invoked with each completion request and the
    /// events streamed back for it. Replaces any previously set callback.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
455
    /// This thread's stable identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt id; called when the user submits a new prompt.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, if one exists yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared system-prompt context for this thread's project.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used before a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, or [`Self::DEFAULT_SUMMARY`] when none exists.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
489
490 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
491 let Some(current_summary) = &self.summary else {
492 // Don't allow setting summary until generated
493 return;
494 };
495
496 let mut new_summary = new_summary.into();
497
498 if new_summary.is_empty() {
499 new_summary = Self::DEFAULT_SUMMARY;
500 }
501
502 if current_summary != &new_summary {
503 self.summary = Some(new_summary);
504 cx.emit(ThreadEvent::SummaryChanged);
505 }
506 }
507
    /// The detailed summary when generated, otherwise the thread's full text.
    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.latest_detailed_summary()
            .unwrap_or_else(|| self.text().into())
    }

    /// The detailed summary text, if generation has completed.
    fn latest_detailed_summary(&self) -> Option<SharedString> {
        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
            Some(text.clone())
        } else {
            None
        }
    }

    /// Looks up a message by id (linear scan).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    /// Iterates over all messages in insertion order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
528
    /// Whether a completion is streaming or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up an in-flight tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses waiting on user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for a message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
558
    /// Restores the project to `checkpoint`'s git state, then truncates the
    /// thread back to the checkpoint's message on success.
    ///
    /// Progress is tracked in `last_restore_checkpoint`: set to `Pending` up
    /// front, to `Error` on failure, and cleared on success. Emits
    /// [`ThreadEvent::CheckpointChanged`] at each transition.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any in-progress checkpoint is invalidated by the restore.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
594
    /// Resolves the checkpoint taken at the last user message once generation
    /// has fully finished.
    ///
    /// Takes a fresh checkpoint and compares it with the pending one: if the
    /// project is unchanged the pending checkpoint is discarded; otherwise it
    /// is kept so the user can roll back. The comparison checkpoint is always
    /// deleted afterwards. No-op while still generating or with no pending
    /// checkpoint.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // On comparison failure, conservatively treat as changed.
                    .unwrap_or(false);

                if equal {
                    // Nothing changed during generation; the checkpoint is useless.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The freshly-taken comparison checkpoint is never kept.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // Couldn't take a comparison checkpoint; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
645
    /// Records a checkpoint against its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Status of the most recent checkpoint restoration, if one happened.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
656
657 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
658 let Some(message_ix) = self
659 .messages
660 .iter()
661 .rposition(|message| message.id == message_id)
662 else {
663 return;
664 };
665 for deleted_message in self.messages.drain(message_ix..) {
666 self.context_by_message.remove(&deleted_message.id);
667 self.checkpoints_by_message.remove(&deleted_message.id);
668 }
669 cx.notify();
670 }
671
    /// Iterates over the context items that were attached to a message.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
        self.context_by_message
            .get(&id)
            .into_iter()
            .flat_map(|context| {
                context
                    .iter()
                    // Ids missing from the context map are silently skipped.
                    .filter_map(|context_id| self.context.get(&context_id))
            })
    }
682
    /// Returns whether all of the tool uses have finished running.
    ///
    /// Errored tool uses count as finished: they will never complete, so
    /// waiting on them would stall the thread.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
692
    /// Tool uses requested by the given (assistant) message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results attached to the given message.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a completed tool use.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool result, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Whether any tool results are attached to the given message.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
716
    /// Filter out contexts that have already been included in previous messages
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }

    /// Whether this context item has not yet been attached to the thread.
    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }
728
729 pub fn insert_user_message(
730 &mut self,
731 text: impl Into<String>,
732 context: Vec<AssistantContext>,
733 git_checkpoint: Option<GitStoreCheckpoint>,
734 cx: &mut Context<Self>,
735 ) -> MessageId {
736 let text = text.into();
737
738 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
739
740 let new_context: Vec<_> = context
741 .into_iter()
742 .filter(|ctx| self.is_context_new(ctx))
743 .collect();
744
745 if !new_context.is_empty() {
746 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
747 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
748 message.context = context_string;
749 }
750 }
751
752 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
753 message.images = new_context
754 .iter()
755 .filter_map(|context| {
756 if let AssistantContext::Image(image_context) = context {
757 image_context.image_task.clone().now_or_never().flatten()
758 } else {
759 None
760 }
761 })
762 .collect::<Vec<_>>();
763 }
764
765 self.action_log.update(cx, |log, cx| {
766 // Track all buffers added as context
767 for ctx in &new_context {
768 match ctx {
769 AssistantContext::File(file_ctx) => {
770 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
771 }
772 AssistantContext::Directory(dir_ctx) => {
773 for context_buffer in &dir_ctx.context_buffers {
774 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
775 }
776 }
777 AssistantContext::Symbol(symbol_ctx) => {
778 log.buffer_added_as_context(
779 symbol_ctx.context_symbol.buffer.clone(),
780 cx,
781 );
782 }
783 AssistantContext::Selection(selection_context) => {
784 log.buffer_added_as_context(
785 selection_context.context_buffer.buffer.clone(),
786 cx,
787 );
788 }
789 AssistantContext::FetchedUrl(_)
790 | AssistantContext::Thread(_)
791 | AssistantContext::Rules(_)
792 | AssistantContext::Image(_) => {}
793 }
794 }
795 });
796 }
797
798 let context_ids = new_context
799 .iter()
800 .map(|context| context.id())
801 .collect::<Vec<_>>();
802 self.context.extend(
803 new_context
804 .into_iter()
805 .map(|context| (context.id(), context)),
806 );
807 self.context_by_message.insert(message_id, context_ids);
808
809 if let Some(git_checkpoint) = git_checkpoint {
810 self.pending_checkpoint = Some(ThreadCheckpoint {
811 message_id,
812 git_checkpoint,
813 });
814 }
815
816 self.auto_capture_telemetry(cx);
817
818 message_id
819 }
820
821 pub fn insert_message(
822 &mut self,
823 role: Role,
824 segments: Vec<MessageSegment>,
825 cx: &mut Context<Self>,
826 ) -> MessageId {
827 let id = self.next_message_id.post_inc();
828 self.messages.push(Message {
829 id,
830 role,
831 segments,
832 context: String::new(),
833 images: Vec::new(),
834 });
835 self.touch_updated_at();
836 cx.emit(ThreadEvent::MessageAdded(id));
837 id
838 }
839
840 pub fn edit_message(
841 &mut self,
842 id: MessageId,
843 new_role: Role,
844 new_segments: Vec<MessageSegment>,
845 cx: &mut Context<Self>,
846 ) -> bool {
847 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
848 return false;
849 };
850 message.role = new_role;
851 message.segments = new_segments;
852 self.touch_updated_at();
853 cx.emit(ThreadEvent::MessageEdited(id));
854 true
855 }
856
857 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
858 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
859 return false;
860 };
861 self.messages.remove(index);
862 self.context_by_message.remove(&id);
863 self.touch_updated_at();
864 cx.emit(ThreadEvent::MessageDeleted(id));
865 true
866 }
867
868 /// Returns the representation of this [`Thread`] in a textual form.
869 ///
870 /// This is the representation we use when attaching a thread as context to another thread.
871 pub fn text(&self) -> String {
872 let mut text = String::new();
873
874 for message in &self.messages {
875 text.push_str(match message.role {
876 language_model::Role::User => "User:",
877 language_model::Role::Assistant => "Assistant:",
878 language_model::Role::System => "System:",
879 });
880 text.push('\n');
881
882 for segment in &message.segments {
883 match segment {
884 MessageSegment::Text(content) => text.push_str(content),
885 MessageSegment::Thinking { text: content, .. } => {
886 text.push_str(&format!("<think>{}</think>", content))
887 }
888 MessageSegment::RedactedThinking(_) => {}
889 }
890 }
891 text.push('\n');
892 }
893
894 text
895 }
896
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot first, then captures messages
    /// together with their tool uses/results. Note: message images are not
    /// serialized (see `deserialize`).
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
960
961 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
962 let mut request = self.to_completion_request(cx);
963 if model.supports_tools() {
964 request.tools = {
965 let mut tools = Vec::new();
966 tools.extend(
967 self.tools()
968 .read(cx)
969 .enabled_tools(cx)
970 .into_iter()
971 .filter_map(|tool| {
972 // Skip tools that cannot be supported
973 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
974 Some(LanguageModelRequestTool {
975 name: tool.name(),
976 description: tool.description(),
977 input_schema,
978 })
979 }),
980 );
981
982 tools
983 };
984 }
985
986 self.stream_completion(request, model, cx);
987 }
988
989 pub fn used_tools_since_last_user_message(&self) -> bool {
990 for message in self.messages.iter().rev() {
991 if self.tool_use.message_has_tool_results(message.id) {
992 return true;
993 } else if message.role == Role::User {
994 return false;
995 }
996 }
997
998 false
999 }
1000
    /// Builds the [`LanguageModelRequest`] for this thread: system prompt,
    /// then each message with its tool results/uses, context, images, and
    /// segments, followed by a stale-file notice if needed.
    ///
    /// Errors generating the system prompt are surfaced via
    /// [`ThreadEvent::ShowError`] and the request proceeds without one.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is stable across requests, so mark it cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Tool results must precede the message's own content.
            self.tool_use
                .attach_tool_results(message.id, &mut request_message);

            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            if !message.images.is_empty() {
                // Some providers only support image parts after an initial text part
                if request_message.content.is_empty() {
                    request_message
                        .content
                        .push(MessageContent::Text("Images attached by user:".to_string()));
                }

                for image in &message.images {
                    request_message
                        .content
                        .push(MessageContent::Image(image.clone()))
                }
            }

            // Empty segments are skipped so providers never see empty parts.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1112
1113 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1114 let mut request = LanguageModelRequest {
1115 thread_id: None,
1116 prompt_id: None,
1117 messages: vec![],
1118 tools: Vec::new(),
1119 stop: Vec::new(),
1120 temperature: None,
1121 };
1122
1123 for message in &self.messages {
1124 let mut request_message = LanguageModelRequestMessage {
1125 role: message.role,
1126 content: Vec::new(),
1127 cache: false,
1128 };
1129
1130 // Skip tool results during summarization.
1131 if self.tool_use.message_has_tool_results(message.id) {
1132 continue;
1133 }
1134
1135 for segment in &message.segments {
1136 match segment {
1137 MessageSegment::Text(text) => request_message
1138 .content
1139 .push(MessageContent::Text(text.clone())),
1140 MessageSegment::Thinking { .. } => {}
1141 MessageSegment::RedactedThinking(_) => {}
1142 }
1143 }
1144
1145 if request_message.content.is_empty() {
1146 continue;
1147 }
1148
1149 request.messages.push(request_message);
1150 }
1151
1152 request.messages.push(LanguageModelRequestMessage {
1153 role: Role::User,
1154 content: vec![MessageContent::Text(added_user_message)],
1155 cache: false,
1156 });
1157
1158 request
1159 }
1160
1161 fn attached_tracked_files_state(
1162 &self,
1163 messages: &mut Vec<LanguageModelRequestMessage>,
1164 cx: &App,
1165 ) {
1166 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1167
1168 let mut stale_message = String::new();
1169
1170 let action_log = self.action_log.read(cx);
1171
1172 for stale_file in action_log.stale_buffers(cx) {
1173 let Some(file) = stale_file.read(cx).file() else {
1174 continue;
1175 };
1176
1177 if stale_message.is_empty() {
1178 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1179 }
1180
1181 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1182 }
1183
1184 let mut content = Vec::with_capacity(2);
1185
1186 if !stale_message.is_empty() {
1187 content.push(stale_message.into());
1188 }
1189
1190 if !content.is_empty() {
1191 let context_message = LanguageModelRequestMessage {
1192 role: Role::User,
1193 content,
1194 cache: false,
1195 };
1196
1197 messages.push(context_message);
1198 }
1199 }
1200
    /// Streams a completion from `model` for `request`, applying each streamed
    /// event to the thread (text/thinking chunks, tool uses, token usage) and
    /// handling errors, summarization, and follow-up tool runs when the stream
    /// ends. The in-flight work is tracked in `self.pending_completions`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Capture the request and its response events only when a callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot usage before streaming so we can report the delta at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                while let Some(event) = events.next().await {
                    // Record every event (including errors) for the request callback.
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Accumulate only the delta since the previous usage update.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking {
                                                text: chunk.to_string(),
                                                signature,
                                            }],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent assistant
                                // message, creating one if none exists yet.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Yield between events so other tasks get a chance to run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                // The model asked for tools; kick them off now.
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known provider errors to specific UI errors;
                            // fall back to a generic message with the full chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Remember the overflow so token-usage UI
                                        // can report it for this model.
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report how many tokens this completion consumed.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1454
    /// Kicks off background generation of a short (3-7 word) title for this
    /// thread using the configured summary model, storing the result in
    /// `self.summary` and emitting `SummaryGenerated` when done. Bails out
    /// silently when no summary model is configured or authenticated.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                // A title should be a single line; keep only the first line of
                // each chunk and stop streaming once a second line appears.
                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1509
    /// Starts generating a detailed Markdown summary of the conversation in the
    /// background. Returns `None` when the thread is empty, when a summary for
    /// the current last message already exists or is being generated, or when
    /// no authenticated summary model is available; otherwise returns the task
    /// driving the generation.
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request itself fails, roll the state back to NotGenerated
            // so a later attempt can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            // Accumulate the streamed chunks; individual chunk errors are
            // logged and skipped rather than aborting the whole summary.
            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        // Mark generation as in-flight for the message we summarized up to.
        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1576
    /// Whether a detailed summary is currently being generated in the background.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1583
    /// Starts running every idle pending tool use. Tools that require user
    /// confirmation (unless `always_allow_tool_actions` is set) are routed to
    /// the confirmation flow instead of running immediately. Returns the
    /// pending tool uses that were considered.
    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        let request = self.to_completion_request(cx);
        let messages = Arc::new(request.messages);
        // Only idle tool uses are eligible; running/finished ones are skipped.
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    // Defer to the user before running this tool.
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }
1624
1625 pub fn run_tool(
1626 &mut self,
1627 tool_use_id: LanguageModelToolUseId,
1628 ui_text: impl Into<SharedString>,
1629 input: serde_json::Value,
1630 messages: &[LanguageModelRequestMessage],
1631 tool: Arc<dyn Tool>,
1632 cx: &mut Context<Thread>,
1633 ) {
1634 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1635 self.tool_use
1636 .run_pending_tool(tool_use_id, ui_text.into(), task);
1637 }
1638
    /// Spawns the actual execution of a tool and returns the task that records
    /// the tool's output on the thread when it completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Disabled tools immediately produce an error result instead of running.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        // Record the output, then notify so the conversation can
                        // continue once all tools have finished.
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
                    })
                    .ok();
            }
        })
    }
1685
    /// Called when a single tool completes (or is canceled). Once every tool
    /// has finished, attaches the results and — unless this finish came from a
    /// cancellation — sends the conversation back to the default model.
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            let model_registry = LanguageModelRegistry::read_global(cx);
            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
                self.attach_tool_results(cx);
                if !canceled {
                    self.send_to_model(model, cx);
                }
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
1708
    /// Insert an empty message to be populated with tool results upon send.
    ///
    /// The new message is a user-role message with no segments; it exists so
    /// the pending tool results have a message id to attach to.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message id, so they will populate
        // this empty message before sending to model. Would prefer this to be more straightforward.
        self.insert_message(Role::User, vec![], cx);
        self.auto_capture_telemetry(cx);
    }
1716
1717 /// Cancels the last pending completion, if there are any pending.
1718 ///
1719 /// Returns whether a completion was canceled.
1720 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1721 let canceled = if self.pending_completions.pop().is_some() {
1722 true
1723 } else {
1724 let mut canceled = false;
1725 for pending_tool_use in self.tool_use.cancel_pending() {
1726 canceled = true;
1727 self.tool_finished(
1728 pending_tool_use.id.clone(),
1729 Some(pending_tool_use),
1730 true,
1731 cx,
1732 );
1733 }
1734 canceled
1735 };
1736 self.finalize_pending_checkpoint(cx);
1737 canceled
1738 }
1739
    /// Returns the thread-level feedback, if any has been reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1743
    /// Returns the feedback reported for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1747
    /// Records `feedback` for `message_id` and reports it via telemetry along
    /// with the serialized thread and a project snapshot. Returns early if the
    /// same feedback was already recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Gather everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than
            // dropping the event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events();

            Ok(())
        })
    }
1805
    /// Reports `feedback` for the thread. When an assistant message exists, the
    /// feedback is attributed to the most recent one; otherwise it is stored at
    /// the thread level and reported via telemetry directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attribute the feedback to; record it at
            // the thread level and emit the telemetry event ourselves.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events();

                Ok(())
            })
        }
    }
1851
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; dirty buffer paths are
    /// collected afterwards on the main context.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record the paths of all dirty (unsaved) file-backed buffers.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
1889
1890 fn worktree_snapshot(
1891 worktree: Entity<project::Worktree>,
1892 git_store: Entity<GitStore>,
1893 cx: &App,
1894 ) -> Task<WorktreeSnapshot> {
1895 cx.spawn(async move |cx| {
1896 // Get worktree path and snapshot
1897 let worktree_info = cx.update(|app_cx| {
1898 let worktree = worktree.read(app_cx);
1899 let path = worktree.abs_path().to_string_lossy().to_string();
1900 let snapshot = worktree.snapshot();
1901 (path, snapshot)
1902 });
1903
1904 let Ok((worktree_path, _snapshot)) = worktree_info else {
1905 return WorktreeSnapshot {
1906 worktree_path: String::new(),
1907 git_state: None,
1908 };
1909 };
1910
1911 let git_state = git_store
1912 .update(cx, |git_store, cx| {
1913 git_store
1914 .repositories()
1915 .values()
1916 .find(|repo| {
1917 repo.read(cx)
1918 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1919 .is_some()
1920 })
1921 .cloned()
1922 })
1923 .ok()
1924 .flatten()
1925 .map(|repo| {
1926 repo.update(cx, |repo, _| {
1927 let current_branch =
1928 repo.branch.as_ref().map(|branch| branch.name.to_string());
1929 repo.send_job(None, |state, _| async move {
1930 let RepositoryState::Local { backend, .. } = state else {
1931 return GitState {
1932 remote_url: None,
1933 head_sha: None,
1934 current_branch,
1935 diff: None,
1936 };
1937 };
1938
1939 let remote_url = backend.remote_url("origin");
1940 let head_sha = backend.head_sha();
1941 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1942
1943 GitState {
1944 remote_url,
1945 head_sha,
1946 current_branch,
1947 diff,
1948 }
1949 })
1950 })
1951 });
1952
1953 let git_state = match git_state {
1954 Some(git_state) => match git_state.ok() {
1955 Some(git_state) => git_state.await.ok(),
1956 None => None,
1957 },
1958 None => None,
1959 };
1960
1961 WorktreeSnapshot {
1962 worktree_path,
1963 git_state,
1964 }
1965 })
1966 }
1967
1968 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1969 let mut markdown = Vec::new();
1970
1971 if let Some(summary) = self.summary() {
1972 writeln!(markdown, "# {summary}\n")?;
1973 };
1974
1975 for message in self.messages() {
1976 writeln!(
1977 markdown,
1978 "## {role}\n",
1979 role = match message.role {
1980 Role::User => "User",
1981 Role::Assistant => "Assistant",
1982 Role::System => "System",
1983 }
1984 )?;
1985
1986 if !message.context.is_empty() {
1987 writeln!(markdown, "{}", message.context)?;
1988 }
1989
1990 for segment in &message.segments {
1991 match segment {
1992 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1993 MessageSegment::Thinking { text, .. } => {
1994 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
1995 }
1996 MessageSegment::RedactedThinking(_) => {}
1997 }
1998 }
1999
2000 for tool_use in self.tool_uses_for_message(message.id, cx) {
2001 writeln!(
2002 markdown,
2003 "**Use Tool: {} ({})**",
2004 tool_use.name, tool_use.id
2005 )?;
2006 writeln!(markdown, "```json")?;
2007 writeln!(
2008 markdown,
2009 "{}",
2010 serde_json::to_string_pretty(&tool_use.input)?
2011 )?;
2012 writeln!(markdown, "```")?;
2013 }
2014
2015 for tool_result in self.tool_results_for_message(message.id) {
2016 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
2017 if tool_result.is_error {
2018 write!(markdown, " (Error)")?;
2019 }
2020
2021 writeln!(markdown, "**\n")?;
2022 writeln!(markdown, "{}", tool_result.content)?;
2023 }
2024 }
2025
2026 Ok(String::from_utf8_lossy(&markdown).to_string())
2027 }
2028
    /// Marks agent edits within `buffer_range` of `buffer` as accepted in the
    /// action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2039
    /// Marks all outstanding agent edits as accepted in the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2044
    /// Rejects agent edits within `buffer_ranges` of `buffer` via the action
    /// log, returning the task performing the rejection.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2055
    /// Returns the action log tracking agent edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2059
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2063
    /// Opportunistically captures the serialized thread for telemetry. Gated
    /// on the `ThreadAutoCapture` feature flag and throttled to at most once
    /// every 10 seconds.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
            return;
        }

        // Throttle: skip if we captured less than 10 seconds ago.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client().clone();
        let serialize_task = self.serialize(cx);

        cx.background_executor()
            .spawn(async move {
                // Serialization or JSON conversion failures silently skip the event.
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events();
                    }
                }
            })
            .detach();
    }
2107
    /// Returns the token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2111
2112 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2113 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2114 return TotalTokenUsage::default();
2115 };
2116
2117 let max = model.model.max_token_count();
2118
2119 let index = self
2120 .messages
2121 .iter()
2122 .position(|msg| msg.id == message_id)
2123 .unwrap_or(0);
2124
2125 if index == 0 {
2126 return TotalTokenUsage { total: 0, max };
2127 }
2128
2129 let token_usage = &self
2130 .request_token_usage
2131 .get(index - 1)
2132 .cloned()
2133 .unwrap_or_default();
2134
2135 TotalTokenUsage {
2136 total: token_usage.total_tokens() as usize,
2137 max,
2138 }
2139 }
2140
    /// Returns the thread's total token usage relative to the default model's
    /// context window. When the window was previously exceeded with this same
    /// model, the recorded overflowing count is reported instead.
    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
        let model_registry = LanguageModelRegistry::read_global(cx);
        let Some(model) = model_registry.default_model() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        // Prefer the recorded overflow if it applies to the current model.
        if let Some(exceeded_error) = &self.exceeded_window_error {
            if model.model.id() == exceeded_error.model_id {
                return TotalTokenUsage {
                    total: exceeded_error.token_count,
                    max,
                };
            }
        }

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens() as usize;

        TotalTokenUsage { total, max }
    }
2165
2166 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2167 self.request_token_usage
2168 .get(self.messages.len().saturating_sub(1))
2169 .or_else(|| self.request_token_usage.last())
2170 .cloned()
2171 }
2172
    /// Records `token_usage` as the usage for the most recent message, first
    /// growing the per-message usage vec (padded with the last known usage) so
    /// its length matches the message count.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Pad with the last known usage so intermediate entries aren't zeroed.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2182
2183 pub fn deny_tool_use(
2184 &mut self,
2185 tool_use_id: LanguageModelToolUseId,
2186 tool_name: Arc<str>,
2187 cx: &mut Context<Self>,
2188 ) {
2189 let err = Err(anyhow::anyhow!(
2190 "Permission to run tool action denied by user"
2191 ));
2192
2193 self.tool_use
2194 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2195 self.tool_finished(tool_use_id.clone(), None, true, cx);
2196 }
2197}
2198
/// User-facing errors surfaced while interacting with a language model.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider rejected the request for lack of payment.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's configured monthly spend cap has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Catch-all error carrying a short header and a detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2213
/// Events emitted by a [`Thread`] as completions stream, tools run, and
/// summaries are generated.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// Updated request-usage information arrived from the provider.
    UsageUpdated(RequestUsage),
    /// A streaming event was applied to the thread.
    StreamedCompletion,
    /// A text chunk was appended to the given assistant message.
    StreamedAssistantText(MessageId, String),
    /// A thinking chunk was appended to the given assistant message.
    StreamedAssistantThinking(MessageId, String),
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A thread title was generated by the summary model.
    SummaryGenerated,
    SummaryChanged,
    /// The model requested tool use; the listed tools are starting.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (or was canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
}
2239
// Allows `Thread` to emit `ThreadEvent`s via `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
2241
/// An in-flight completion: `id` identifies it in `pending_completions`, and
/// `_task` keeps the task driving the stream alive.
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}
2246
2247#[cfg(test)]
2248mod tests {
2249 use super::*;
2250 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2251 use assistant_settings::AssistantSettings;
2252 use context_server::ContextServerSettings;
2253 use editor::EditorSettings;
2254 use gpui::TestAppContext;
2255 use project::{FakeFs, Project};
2256 use prompt_store::PromptBuilder;
2257 use serde_json::json;
2258 use settings::{Settings, SettingsStore};
2259 use std::sync::Arc;
2260 use theme::ThemeSettings;
2261 use util::path;
2262 use workspace::Workspace;
2263
    /// Verifies that a file attached as context is formatted into the user
    /// message's `context` field and prepended to the user's text in the
    /// completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request: the user message at index 1 should be the
        // context string followed by the user's text.
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2329
    /// Verifies deduplication of attached contexts: a context item is only
    /// included in the first message that references it, and later messages
    /// carry only the contexts that are new to the thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain all 3 messages (plus one leading message
        // at index 0 before the user messages)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }
2431
    /// Verifies that user messages inserted with no attached context have an
    /// empty `context` field and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2493
    /// Verifies that when a buffer attached as context is modified after
    /// being read, the next completion request ends with a "files changed
    /// since last read" notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer: insert text at byte offset 1 (any edit suffices
        // to make the buffer stale relative to the contents read as context)
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2568
    /// Registers the global [`SettingsStore`] and every settings type these
    /// tests depend on. Called at the start of each test; the settings store
    /// is installed first since the registrations below read from it.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }
2584
2585 // Helper to create a test project with test files
2586 async fn create_test_project(
2587 cx: &mut TestAppContext,
2588 files: serde_json::Value,
2589 ) -> Entity<Project> {
2590 let fs = FakeFs::new(cx.executor());
2591 fs.insert_tree(path!("/test"), files).await;
2592 Project::test(fs, [path!("/test").as_ref()], cx).await
2593 }
2594
    /// Creates the full test fixture for `project`: a test [`Workspace`], a
    /// loaded [`ThreadStore`], a fresh [`Thread`], and an empty
    /// [`ContextStore`].
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        // Note: `cx` is shadowed here with the window-scoped context returned
        // by `add_window_view`; all subsequent updates go through it.
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }
2624
2625 async fn add_file_to_context(
2626 project: &Entity<Project>,
2627 context_store: &Entity<ContextStore>,
2628 path: &str,
2629 cx: &mut TestAppContext,
2630 ) -> Result<Entity<language::Buffer>> {
2631 let buffer_path = project
2632 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2633 .unwrap();
2634
2635 let buffer = project
2636 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2637 .await
2638 .unwrap();
2639
2640 context_store
2641 .update(cx, |store, cx| {
2642 store.add_file_from_buffer(buffer.clone(), cx)
2643 })
2644 .await?;
2645
2646 Ok(buffer)
2647 }
2648}