1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState, USING_TOOL_MARKER};
42
43#[derive(
44 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
45)]
46pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
66/// The ID of the user prompt that initiated a request.
67///
68/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
70pub struct PromptId(Arc<str>);
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
84#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
85pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
93/// A message in a [`Thread`].
94#[derive(Debug, Clone)]
95pub struct Message {
96 pub id: MessageId,
97 pub role: Role,
98 pub segments: Vec<MessageSegment>,
99 pub context: String,
100}
101
102impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    ///
    /// The model sometimes runs tools without producing any text, or produces only a
    /// marker ([`USING_TOOL_MARKER`]).
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }
108
109 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
110 if let Some(MessageSegment::Thinking {
111 text: segment,
112 signature: current_signature,
113 }) = self.segments.last_mut()
114 {
115 if let Some(signature) = signature {
116 *current_signature = Some(signature);
117 }
118 segment.push_str(text);
119 } else {
120 self.segments.push(MessageSegment::Thinking {
121 text: text.to_string(),
122 signature,
123 });
124 }
125 }
126
127 pub fn push_text(&mut self, text: &str) {
128 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
129 segment.push_str(text);
130 } else {
131 self.segments.push(MessageSegment::Text(text.to_string()));
132 }
133 }
134
135 pub fn to_string(&self) -> String {
136 let mut result = String::new();
137
138 if !self.context.is_empty() {
139 result.push_str(&self.context);
140 }
141
142 for segment in &self.segments {
143 match segment {
144 MessageSegment::Text(text) => result.push_str(text),
145 MessageSegment::Thinking { text, .. } => {
146 result.push_str("<think>\n");
147 result.push_str(text);
148 result.push_str("\n</think>");
149 }
150 MessageSegment::RedactedThinking(_) => {}
151 }
152 }
153
154 result
155 }
156}
157
158#[derive(Debug, Clone, PartialEq, Eq)]
159pub enum MessageSegment {
160 Text(String),
161 Thinking {
162 text: String,
163 signature: Option<String>,
164 },
165 RedactedThinking(Vec<u8>),
166}
167
168impl MessageSegment {
169 pub fn should_display(&self) -> bool {
170 // We add USING_TOOL_MARKER when making a request that includes tool uses
171 // without non-whitespace text around them, and this can cause the model
172 // to mimic the pattern, so we consider those segments not displayable.
173 match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking { text, .. } => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
176 Self::RedactedThinking(_) => false,
177 }
178 }
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct ProjectSnapshot {
183 pub worktree_snapshots: Vec<WorktreeSnapshot>,
184 pub unsaved_buffer_paths: Vec<String>,
185 pub timestamp: DateTime<Utc>,
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct WorktreeSnapshot {
190 pub worktree_path: String,
191 pub git_state: Option<GitState>,
192}
193
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct GitState {
196 pub remote_url: Option<String>,
197 pub head_sha: Option<String>,
198 pub current_branch: Option<String>,
199 pub diff: Option<String>,
200}
201
202#[derive(Clone)]
203pub struct ThreadCheckpoint {
204 message_id: MessageId,
205 git_checkpoint: GitStoreCheckpoint,
206}
207
208#[derive(Copy, Clone, Debug, PartialEq, Eq)]
209pub enum ThreadFeedback {
210 Positive,
211 Negative,
212}
213
214pub enum LastRestoreCheckpoint {
215 Pending {
216 message_id: MessageId,
217 },
218 Error {
219 message_id: MessageId,
220 error: String,
221 },
222}
223
224impl LastRestoreCheckpoint {
225 pub fn message_id(&self) -> MessageId {
226 match self {
227 LastRestoreCheckpoint::Pending { message_id } => *message_id,
228 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
229 }
230 }
231}
232
233#[derive(Clone, Debug, Default, Serialize, Deserialize)]
234pub enum DetailedSummaryState {
235 #[default]
236 NotGenerated,
237 Generating {
238 message_id: MessageId,
239 },
240 Generated {
241 text: SharedString,
242 message_id: MessageId,
243 },
244}
245
246#[derive(Default)]
247pub struct TotalTokenUsage {
248 pub total: usize,
249 pub max: usize,
250}
251
252impl TotalTokenUsage {
253 pub fn ratio(&self) -> TokenUsageRatio {
254 #[cfg(debug_assertions)]
255 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
256 .unwrap_or("0.8".to_string())
257 .parse()
258 .unwrap();
259 #[cfg(not(debug_assertions))]
260 let warning_threshold: f32 = 0.8;
261
262 if self.total >= self.max {
263 TokenUsageRatio::Exceeded
264 } else if self.total as f32 / self.max as f32 >= warning_threshold {
265 TokenUsageRatio::Warning
266 } else {
267 TokenUsageRatio::Normal
268 }
269 }
270
271 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
272 TotalTokenUsage {
273 total: self.total + tokens,
274 max: self.max,
275 }
276 }
277}
278
279#[derive(Debug, Default, PartialEq, Eq)]
280pub enum TokenUsageRatio {
281 #[default]
282 Normal,
283 Warning,
284 Exceeded,
285}
286
287/// A thread of conversation with the LLM.
288pub struct Thread {
289 id: ThreadId,
290 updated_at: DateTime<Utc>,
291 summary: Option<SharedString>,
292 pending_summary: Task<Option<()>>,
293 detailed_summary_state: DetailedSummaryState,
294 messages: Vec<Message>,
295 next_message_id: MessageId,
296 last_prompt_id: PromptId,
297 context: BTreeMap<ContextId, AssistantContext>,
298 context_by_message: HashMap<MessageId, Vec<ContextId>>,
299 project_context: SharedProjectContext,
300 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
301 completion_count: usize,
302 pending_completions: Vec<PendingCompletion>,
303 project: Entity<Project>,
304 prompt_builder: Arc<PromptBuilder>,
305 tools: Entity<ToolWorkingSet>,
306 tool_use: ToolUseState,
307 action_log: Entity<ActionLog>,
308 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
309 pending_checkpoint: Option<ThreadCheckpoint>,
310 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
311 request_token_usage: Vec<TokenUsage>,
312 cumulative_token_usage: TokenUsage,
313 exceeded_window_error: Option<ExceededWindowError>,
314 feedback: Option<ThreadFeedback>,
315 message_feedback: HashMap<MessageId, ThreadFeedback>,
316 last_auto_capture_at: Option<Instant>,
317 request_callback: Option<
318 Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
319 >,
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize)]
323pub struct ExceededWindowError {
324 /// Model used when last message exceeded context window
325 model_id: LanguageModelId,
326 /// Token count including last message
327 token_count: usize,
328}
329
330impl Thread {
331 pub fn new(
332 project: Entity<Project>,
333 tools: Entity<ToolWorkingSet>,
334 prompt_builder: Arc<PromptBuilder>,
335 system_prompt: SharedProjectContext,
336 cx: &mut Context<Self>,
337 ) -> Self {
338 Self {
339 id: ThreadId::new(),
340 updated_at: Utc::now(),
341 summary: None,
342 pending_summary: Task::ready(None),
343 detailed_summary_state: DetailedSummaryState::NotGenerated,
344 messages: Vec::new(),
345 next_message_id: MessageId(0),
346 last_prompt_id: PromptId::new(),
347 context: BTreeMap::default(),
348 context_by_message: HashMap::default(),
349 project_context: system_prompt,
350 checkpoints_by_message: HashMap::default(),
351 completion_count: 0,
352 pending_completions: Vec::new(),
353 project: project.clone(),
354 prompt_builder,
355 tools: tools.clone(),
356 last_restore_checkpoint: None,
357 pending_checkpoint: None,
358 tool_use: ToolUseState::new(tools.clone()),
359 action_log: cx.new(|_| ActionLog::new(project.clone())),
360 initial_project_snapshot: {
361 let project_snapshot = Self::project_snapshot(project, cx);
362 cx.foreground_executor()
363 .spawn(async move { Some(project_snapshot.await) })
364 .shared()
365 },
366 request_token_usage: Vec::new(),
367 cumulative_token_usage: TokenUsage::default(),
368 exceeded_window_error: None,
369 feedback: None,
370 message_feedback: HashMap::default(),
371 last_auto_capture_at: None,
372 request_callback: None,
373 }
374 }
375
376 pub fn deserialize(
377 id: ThreadId,
378 serialized: SerializedThread,
379 project: Entity<Project>,
380 tools: Entity<ToolWorkingSet>,
381 prompt_builder: Arc<PromptBuilder>,
382 project_context: SharedProjectContext,
383 cx: &mut Context<Self>,
384 ) -> Self {
385 let next_message_id = MessageId(
386 serialized
387 .messages
388 .last()
389 .map(|message| message.id.0 + 1)
390 .unwrap_or(0),
391 );
392 let tool_use =
393 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
394
395 Self {
396 id,
397 updated_at: serialized.updated_at,
398 summary: Some(serialized.summary),
399 pending_summary: Task::ready(None),
400 detailed_summary_state: serialized.detailed_summary_state,
401 messages: serialized
402 .messages
403 .into_iter()
404 .map(|message| Message {
405 id: message.id,
406 role: message.role,
407 segments: message
408 .segments
409 .into_iter()
410 .map(|segment| match segment {
411 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
412 SerializedMessageSegment::Thinking { text, signature } => {
413 MessageSegment::Thinking { text, signature }
414 }
415 SerializedMessageSegment::RedactedThinking { data } => {
416 MessageSegment::RedactedThinking(data)
417 }
418 })
419 .collect(),
420 context: message.context,
421 })
422 .collect(),
423 next_message_id,
424 last_prompt_id: PromptId::new(),
425 context: BTreeMap::default(),
426 context_by_message: HashMap::default(),
427 project_context,
428 checkpoints_by_message: HashMap::default(),
429 completion_count: 0,
430 pending_completions: Vec::new(),
431 last_restore_checkpoint: None,
432 pending_checkpoint: None,
433 project: project.clone(),
434 prompt_builder,
435 tools,
436 tool_use,
437 action_log: cx.new(|_| ActionLog::new(project)),
438 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
439 request_token_usage: serialized.request_token_usage,
440 cumulative_token_usage: serialized.cumulative_token_usage,
441 exceeded_window_error: None,
442 feedback: None,
443 message_feedback: HashMap::default(),
444 last_auto_capture_at: None,
445 request_callback: None,
446 }
447 }
448
449 pub fn set_request_callback(
450 &mut self,
451 callback: impl 'static
452 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
453 ) {
454 self.request_callback = Some(Box::new(callback));
455 }
456
457 pub fn id(&self) -> &ThreadId {
458 &self.id
459 }
460
461 pub fn is_empty(&self) -> bool {
462 self.messages.is_empty()
463 }
464
465 pub fn updated_at(&self) -> DateTime<Utc> {
466 self.updated_at
467 }
468
469 pub fn touch_updated_at(&mut self) {
470 self.updated_at = Utc::now();
471 }
472
473 pub fn advance_prompt_id(&mut self) {
474 self.last_prompt_id = PromptId::new();
475 }
476
477 pub fn summary(&self) -> Option<SharedString> {
478 self.summary.clone()
479 }
480
481 pub fn project_context(&self) -> SharedProjectContext {
482 self.project_context.clone()
483 }
484
485 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
486
487 pub fn summary_or_default(&self) -> SharedString {
488 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
489 }
490
491 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
492 let Some(current_summary) = &self.summary else {
493 // Don't allow setting summary until generated
494 return;
495 };
496
497 let mut new_summary = new_summary.into();
498
499 if new_summary.is_empty() {
500 new_summary = Self::DEFAULT_SUMMARY;
501 }
502
503 if current_summary != &new_summary {
504 self.summary = Some(new_summary);
505 cx.emit(ThreadEvent::SummaryChanged);
506 }
507 }
508
509 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
510 self.latest_detailed_summary()
511 .unwrap_or_else(|| self.text().into())
512 }
513
514 fn latest_detailed_summary(&self) -> Option<SharedString> {
515 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
516 Some(text.clone())
517 } else {
518 None
519 }
520 }
521
522 pub fn message(&self, id: MessageId) -> Option<&Message> {
523 self.messages.iter().find(|message| message.id == id)
524 }
525
526 pub fn messages(&self) -> impl Iterator<Item = &Message> {
527 self.messages.iter()
528 }
529
530 pub fn is_generating(&self) -> bool {
531 !self.pending_completions.is_empty() || !self.all_tools_finished()
532 }
533
534 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
535 &self.tools
536 }
537
538 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
539 self.tool_use
540 .pending_tool_uses()
541 .into_iter()
542 .find(|tool_use| &tool_use.id == id)
543 }
544
545 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
546 self.tool_use
547 .pending_tool_uses()
548 .into_iter()
549 .filter(|tool_use| tool_use.status.needs_confirmation())
550 }
551
552 pub fn has_pending_tool_uses(&self) -> bool {
553 !self.tool_use.pending_tool_uses().is_empty()
554 }
555
556 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
557 self.checkpoints_by_message.get(&id).cloned()
558 }
559
560 pub fn restore_checkpoint(
561 &mut self,
562 checkpoint: ThreadCheckpoint,
563 cx: &mut Context<Self>,
564 ) -> Task<Result<()>> {
565 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
566 message_id: checkpoint.message_id,
567 });
568 cx.emit(ThreadEvent::CheckpointChanged);
569 cx.notify();
570
571 let git_store = self.project().read(cx).git_store().clone();
572 let restore = git_store.update(cx, |git_store, cx| {
573 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
574 });
575
576 cx.spawn(async move |this, cx| {
577 let result = restore.await;
578 this.update(cx, |this, cx| {
579 if let Err(err) = result.as_ref() {
580 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
581 message_id: checkpoint.message_id,
582 error: err.to_string(),
583 });
584 } else {
585 this.truncate(checkpoint.message_id, cx);
586 this.last_restore_checkpoint = None;
587 }
588 this.pending_checkpoint = None;
589 cx.emit(ThreadEvent::CheckpointChanged);
590 cx.notify();
591 })?;
592 result
593 })
594 }
595
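    /// Once generation has finished, compares the pending checkpoint against the repository's
    /// current state: if nothing changed, the pending checkpoint is deleted; otherwise it is
    /// recorded for the message that created it. The temporary comparison checkpoint is always
    /// cleaned up.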
596 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
597 let pending_checkpoint = if self.is_generating() {
598 return;
599 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
600 checkpoint
601 } else {
602 return;
603 };
604
605 let git_store = self.project.read(cx).git_store().clone();
606 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
607 cx.spawn(async move |this, cx| match final_checkpoint.await {
608 Ok(final_checkpoint) => {
609 let equal = git_store
610 .update(cx, |store, cx| {
611 store.compare_checkpoints(
612 pending_checkpoint.git_checkpoint.clone(),
613 final_checkpoint.clone(),
614 cx,
615 )
616 })?
617 .await
618 .unwrap_or(false);
619
620 if equal {
621 git_store
622 .update(cx, |store, cx| {
623 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
624 })?
625 .detach();
626 } else {
627 this.update(cx, |this, cx| {
628 this.insert_checkpoint(pending_checkpoint, cx)
629 })?;
630 }
631
632 git_store
633 .update(cx, |store, cx| {
634 store.delete_checkpoint(final_checkpoint, cx)
635 })?
636 .detach();
637
638 Ok(())
639 }
640 Err(_) => this.update(cx, |this, cx| {
641 this.insert_checkpoint(pending_checkpoint, cx)
642 }),
643 })
644 .detach();
645 }
646
647 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
648 self.checkpoints_by_message
649 .insert(checkpoint.message_id, checkpoint);
650 cx.emit(ThreadEvent::CheckpointChanged);
651 cx.notify();
652 }
653
654 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
655 self.last_restore_checkpoint.as_ref()
656 }
657
658 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
659 let Some(message_ix) = self
660 .messages
661 .iter()
662 .rposition(|message| message.id == message_id)
663 else {
664 return;
665 };
666 for deleted_message in self.messages.drain(message_ix..) {
667 self.context_by_message.remove(&deleted_message.id);
668 self.checkpoints_by_message.remove(&deleted_message.id);
669 }
670 cx.notify();
671 }
672
673 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
674 self.context_by_message
675 .get(&id)
676 .into_iter()
677 .flat_map(|context| {
678 context
679 .iter()
                    .filter_map(|context_id| self.context.get(context_id))
681 })
682 }
683
684 /// Returns whether all of the tool uses have finished running.
685 pub fn all_tools_finished(&self) -> bool {
686 // If the only pending tool uses left are the ones with errors, then
687 // that means that we've finished running all of the pending tools.
688 self.tool_use
689 .pending_tool_uses()
690 .iter()
691 .all(|tool_use| tool_use.status.is_error())
692 }
693
694 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
695 self.tool_use.tool_uses_for_message(id, cx)
696 }
697
698 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
699 self.tool_use.tool_results_for_message(id)
700 }
701
702 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
703 self.tool_use.tool_result(id)
704 }
705
706 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
707 Some(&self.tool_use.tool_result(id)?.content)
708 }
709
710 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
711 self.tool_use.tool_result_card(id).cloned()
712 }
713
714 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
715 self.tool_use.message_has_tool_results(message_id)
716 }
717
718 /// Filter out contexts that have already been included in previous messages
719 pub fn filter_new_context<'a>(
720 &self,
721 context: impl Iterator<Item = &'a AssistantContext>,
722 ) -> impl Iterator<Item = &'a AssistantContext> {
723 context.filter(|ctx| self.is_context_new(ctx))
724 }
725
726 fn is_context_new(&self, context: &AssistantContext) -> bool {
727 !self.context.contains_key(&context.id())
728 }
729
730 pub fn insert_user_message(
731 &mut self,
732 text: impl Into<String>,
733 context: Vec<AssistantContext>,
734 git_checkpoint: Option<GitStoreCheckpoint>,
735 cx: &mut Context<Self>,
736 ) -> MessageId {
737 let text = text.into();
738
739 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
740
741 let new_context: Vec<_> = context
742 .into_iter()
743 .filter(|ctx| self.is_context_new(ctx))
744 .collect();
745
746 if !new_context.is_empty() {
747 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
748 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
749 message.context = context_string;
750 }
751 }
752
753 self.action_log.update(cx, |log, cx| {
754 // Track all buffers added as context
755 for ctx in &new_context {
756 match ctx {
757 AssistantContext::File(file_ctx) => {
758 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
759 }
760 AssistantContext::Directory(dir_ctx) => {
761 for context_buffer in &dir_ctx.context_buffers {
762 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
763 }
764 }
765 AssistantContext::Symbol(symbol_ctx) => {
766 log.buffer_added_as_context(
767 symbol_ctx.context_symbol.buffer.clone(),
768 cx,
769 );
770 }
771 AssistantContext::Excerpt(excerpt_context) => {
772 log.buffer_added_as_context(
773 excerpt_context.context_buffer.buffer.clone(),
774 cx,
775 );
776 }
777 AssistantContext::FetchedUrl(_)
778 | AssistantContext::Thread(_)
779 | AssistantContext::Rules(_) => {}
780 }
781 }
782 });
783 }
784
785 let context_ids = new_context
786 .iter()
787 .map(|context| context.id())
788 .collect::<Vec<_>>();
789 self.context.extend(
790 new_context
791 .into_iter()
792 .map(|context| (context.id(), context)),
793 );
794 self.context_by_message.insert(message_id, context_ids);
795
796 if let Some(git_checkpoint) = git_checkpoint {
797 self.pending_checkpoint = Some(ThreadCheckpoint {
798 message_id,
799 git_checkpoint,
800 });
801 }
802
803 self.auto_capture_telemetry(cx);
804
805 message_id
806 }
807
808 pub fn insert_message(
809 &mut self,
810 role: Role,
811 segments: Vec<MessageSegment>,
812 cx: &mut Context<Self>,
813 ) -> MessageId {
814 let id = self.next_message_id.post_inc();
815 self.messages.push(Message {
816 id,
817 role,
818 segments,
819 context: String::new(),
820 });
821 self.touch_updated_at();
822 cx.emit(ThreadEvent::MessageAdded(id));
823 id
824 }
825
826 pub fn edit_message(
827 &mut self,
828 id: MessageId,
829 new_role: Role,
830 new_segments: Vec<MessageSegment>,
831 cx: &mut Context<Self>,
832 ) -> bool {
833 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
834 return false;
835 };
836 message.role = new_role;
837 message.segments = new_segments;
838 self.touch_updated_at();
839 cx.emit(ThreadEvent::MessageEdited(id));
840 true
841 }
842
843 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
844 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
845 return false;
846 };
847 self.messages.remove(index);
848 self.context_by_message.remove(&id);
849 self.touch_updated_at();
850 cx.emit(ThreadEvent::MessageDeleted(id));
851 true
852 }
853
854 /// Returns the representation of this [`Thread`] in a textual form.
855 ///
856 /// This is the representation we use when attaching a thread as context to another thread.
857 pub fn text(&self) -> String {
858 let mut text = String::new();
859
860 for message in &self.messages {
861 text.push_str(match message.role {
862 language_model::Role::User => "User:",
863 language_model::Role::Assistant => "Assistant:",
864 language_model::Role::System => "System:",
865 });
866 text.push('\n');
867
868 for segment in &message.segments {
869 match segment {
870 MessageSegment::Text(content) => text.push_str(content),
871 MessageSegment::Thinking { text: content, .. } => {
872 text.push_str(&format!("<think>{}</think>", content))
873 }
874 MessageSegment::RedactedThinking(_) => {}
875 }
876 }
877 text.push('\n');
878 }
879
880 text
881 }
882
883 /// Serializes this thread into a format for storage or telemetry.
884 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
885 let initial_project_snapshot = self.initial_project_snapshot.clone();
886 cx.spawn(async move |this, cx| {
887 let initial_project_snapshot = initial_project_snapshot.await;
888 this.read_with(cx, |this, cx| SerializedThread {
889 version: SerializedThread::VERSION.to_string(),
890 summary: this.summary_or_default(),
891 updated_at: this.updated_at(),
892 messages: this
893 .messages()
894 .map(|message| SerializedMessage {
895 id: message.id,
896 role: message.role,
897 segments: message
898 .segments
899 .iter()
900 .map(|segment| match segment {
901 MessageSegment::Text(text) => {
902 SerializedMessageSegment::Text { text: text.clone() }
903 }
904 MessageSegment::Thinking { text, signature } => {
905 SerializedMessageSegment::Thinking {
906 text: text.clone(),
907 signature: signature.clone(),
908 }
909 }
910 MessageSegment::RedactedThinking(data) => {
911 SerializedMessageSegment::RedactedThinking {
912 data: data.clone(),
913 }
914 }
915 })
916 .collect(),
917 tool_uses: this
918 .tool_uses_for_message(message.id, cx)
919 .into_iter()
920 .map(|tool_use| SerializedToolUse {
921 id: tool_use.id,
922 name: tool_use.name,
923 input: tool_use.input,
924 })
925 .collect(),
926 tool_results: this
927 .tool_results_for_message(message.id)
928 .into_iter()
929 .map(|tool_result| SerializedToolResult {
930 tool_use_id: tool_result.tool_use_id.clone(),
931 is_error: tool_result.is_error,
932 content: tool_result.content.clone(),
933 })
934 .collect(),
935 context: message.context.clone(),
936 })
937 .collect(),
938 initial_project_snapshot,
939 cumulative_token_usage: this.cumulative_token_usage,
940 request_token_usage: this.request_token_usage.clone(),
941 detailed_summary_state: this.detailed_summary_state.clone(),
942 exceeded_window_error: this.exceeded_window_error.clone(),
943 })
944 })
945 }
946
947 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
948 let mut request = self.to_completion_request(cx);
949 if model.supports_tools() {
950 request.tools = {
951 let mut tools = Vec::new();
952 tools.extend(
953 self.tools()
954 .read(cx)
955 .enabled_tools(cx)
956 .into_iter()
957 .filter_map(|tool| {
958 // Skip tools that cannot be supported
959 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
960 Some(LanguageModelRequestTool {
961 name: tool.name(),
962 description: tool.description(),
963 input_schema,
964 })
965 }),
966 );
967
968 tools
969 };
970 }
971
972 self.stream_completion(request, model, cx);
973 }
974
975 pub fn used_tools_since_last_user_message(&self) -> bool {
976 for message in self.messages.iter().rev() {
977 if self.tool_use.message_has_tool_results(message.id) {
978 return true;
979 } else if message.role == Role::User {
980 return false;
981 }
982 }
983
984 false
985 }
986
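    /// Assembles the full [`LanguageModelRequest`] for this thread: the generated system prompt,
    /// every message with its context, tool uses, and tool results, plus a trailing note about
    /// stale files. The final message is marked as cacheable.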
987 pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
988 let mut request = LanguageModelRequest {
989 thread_id: Some(self.id.to_string()),
990 prompt_id: Some(self.last_prompt_id.to_string()),
991 messages: vec![],
992 tools: Vec::new(),
993 stop: Vec::new(),
994 temperature: None,
995 };
996
997 if let Some(project_context) = self.project_context.borrow().as_ref() {
998 match self
999 .prompt_builder
1000 .generate_assistant_system_prompt(project_context)
1001 {
1002 Err(err) => {
1003 let message = format!("{err:?}").into();
1004 log::error!("{message}");
1005 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1006 header: "Error generating system prompt".into(),
1007 message,
1008 }));
1009 }
1010 Ok(system_prompt) => {
1011 request.messages.push(LanguageModelRequestMessage {
1012 role: Role::System,
1013 content: vec![MessageContent::Text(system_prompt)],
1014 cache: true,
1015 });
1016 }
1017 }
1018 } else {
1019 let message = "Context for system prompt unexpectedly not ready.".into();
1020 log::error!("{message}");
1021 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1022 header: "Error generating system prompt".into(),
1023 message,
1024 }));
1025 }
1026
1027 for message in &self.messages {
1028 let mut request_message = LanguageModelRequestMessage {
1029 role: message.role,
1030 content: Vec::new(),
1031 cache: false,
1032 };
1033
1034 self.tool_use
1035 .attach_tool_results(message.id, &mut request_message);
1036
1037 if !message.context.is_empty() {
1038 request_message
1039 .content
1040 .push(MessageContent::Text(message.context.to_string()));
1041 }
1042
1043 for segment in &message.segments {
1044 match segment {
1045 MessageSegment::Text(text) => {
1046 if !text.is_empty() {
1047 request_message
1048 .content
1049 .push(MessageContent::Text(text.into()));
1050 }
1051 }
1052 MessageSegment::Thinking { text, signature } => {
1053 if !text.is_empty() {
1054 request_message.content.push(MessageContent::Thinking {
1055 text: text.into(),
1056 signature: signature.clone(),
1057 });
1058 }
1059 }
1060 MessageSegment::RedactedThinking(data) => {
1061 request_message
1062 .content
1063 .push(MessageContent::RedactedThinking(data.clone()));
1064 }
1065 };
1066 }
1067
1068 self.tool_use
1069 .attach_tool_uses(message.id, &mut request_message);
1070
1071 request.messages.push(request_message);
1072 }
1073
1074 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1075 if let Some(last) = request.messages.last_mut() {
1076 last.cache = true;
1077 }
1078
1079 self.attached_tracked_files_state(&mut request.messages, cx);
1080
1081 request
1082 }
1083
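    /// Builds a request containing only the plain-text conversation (messages with tool results
    /// and all thinking segments are skipped), followed by a user message with
    /// `added_user_message`. Used when generating thread summaries.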
1084 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1085 let mut request = LanguageModelRequest {
1086 thread_id: None,
1087 prompt_id: None,
1088 messages: vec![],
1089 tools: Vec::new(),
1090 stop: Vec::new(),
1091 temperature: None,
1092 };
1093
1094 for message in &self.messages {
1095 let mut request_message = LanguageModelRequestMessage {
1096 role: message.role,
1097 content: Vec::new(),
1098 cache: false,
1099 };
1100
1101 // Skip tool results during summarization.
1102 if self.tool_use.message_has_tool_results(message.id) {
1103 continue;
1104 }
1105
1106 for segment in &message.segments {
1107 match segment {
1108 MessageSegment::Text(text) => request_message
1109 .content
1110 .push(MessageContent::Text(text.clone())),
1111 MessageSegment::Thinking { .. } => {}
1112 MessageSegment::RedactedThinking(_) => {}
1113 }
1114 }
1115
1116 if request_message.content.is_empty() {
1117 continue;
1118 }
1119
1120 request.messages.push(request_message);
1121 }
1122
1123 request.messages.push(LanguageModelRequestMessage {
1124 role: Role::User,
1125 content: vec![MessageContent::Text(added_user_message)],
1126 cache: false,
1127 });
1128
1129 request
1130 }
1131
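    /// Appends a trailing user message listing any context buffers that have changed since the
    /// model last read them.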
1132 fn attached_tracked_files_state(
1133 &self,
1134 messages: &mut Vec<LanguageModelRequestMessage>,
1135 cx: &App,
1136 ) {
1137 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1138
1139 let mut stale_message = String::new();
1140
1141 let action_log = self.action_log.read(cx);
1142
1143 for stale_file in action_log.stale_buffers(cx) {
1144 let Some(file) = stale_file.read(cx).file() else {
1145 continue;
1146 };
1147
1148 if stale_message.is_empty() {
1149 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1150 }
1151
1152 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1153 }
1154
1155 let mut content = Vec::with_capacity(2);
1156
1157 if !stale_message.is_empty() {
1158 content.push(stale_message.into());
1159 }
1160
1161 if !content.is_empty() {
1162 let context_message = LanguageModelRequestMessage {
1163 role: Role::User,
1164 content,
1165 cache: false,
1166 };
1167
1168 messages.push(context_message);
1169 }
1170 }
1171
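    /// Starts streaming a completion for `request` from `model`, applying events to the thread as
    /// they arrive and handling errors, tool-use stops, and telemetry once the stream finishes.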
1172 pub fn stream_completion(
1173 &mut self,
1174 request: LanguageModelRequest,
1175 model: Arc<dyn LanguageModel>,
1176 cx: &mut Context<Self>,
1177 ) {
1178 let pending_completion_id = post_inc(&mut self.completion_count);
1179 let mut request_callback_parameters = if self.request_callback.is_some() {
1180 Some((request.clone(), Vec::new()))
1181 } else {
1182 None
1183 };
1184 let prompt_id = self.last_prompt_id.clone();
1185 let tool_use_metadata = ToolUseMetadata {
1186 model: model.clone(),
1187 thread_id: self.id.clone(),
1188 prompt_id: prompt_id.clone(),
1189 };
1190
1191 let task = cx.spawn(async move |thread, cx| {
1192 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1193 let initial_token_usage =
1194 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1195 let stream_completion = async {
1196 let (mut events, usage) = stream_completion_future.await?;
1197
1198 let mut stop_reason = StopReason::EndTurn;
1199 let mut current_token_usage = TokenUsage::default();
1200
1201 if let Some(usage) = usage {
1202 thread
1203 .update(cx, |_thread, cx| {
1204 cx.emit(ThreadEvent::UsageUpdated(usage));
1205 })
1206 .ok();
1207 }
1208
1209 while let Some(event) = events.next().await {
1210 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1211 response_events
1212 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1213 }
1214
1215 let event = event?;
1216
1217 thread.update(cx, |thread, cx| {
1218 match event {
1219 LanguageModelCompletionEvent::StartMessage { .. } => {
1220 thread.insert_message(
1221 Role::Assistant,
1222 vec![MessageSegment::Text(String::new())],
1223 cx,
1224 );
1225 }
1226 LanguageModelCompletionEvent::Stop(reason) => {
1227 stop_reason = reason;
1228 }
1229 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1230 thread.update_token_usage_at_last_message(token_usage);
1231 thread.cumulative_token_usage = thread.cumulative_token_usage
1232 + token_usage
1233 - current_token_usage;
1234 current_token_usage = token_usage;
1235 }
1236 LanguageModelCompletionEvent::Text(chunk) => {
1237 if let Some(last_message) = thread.messages.last_mut() {
1238 if last_message.role == Role::Assistant {
1239 last_message.push_text(&chunk);
1240 cx.emit(ThreadEvent::StreamedAssistantText(
1241 last_message.id,
1242 chunk,
1243 ));
1244 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1246 // of a new Assistant response.
1247 //
1248 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1249 // will result in duplicating the text of the chunk in the rendered Markdown.
1250 thread.insert_message(
1251 Role::Assistant,
1252 vec![MessageSegment::Text(chunk.to_string())],
1253 cx,
1254 );
1255 };
1256 }
1257 }
1258 LanguageModelCompletionEvent::Thinking {
1259 text: chunk,
1260 signature,
1261 } => {
1262 if let Some(last_message) = thread.messages.last_mut() {
1263 if last_message.role == Role::Assistant {
1264 last_message.push_thinking(&chunk, signature);
1265 cx.emit(ThreadEvent::StreamedAssistantThinking(
1266 last_message.id,
1267 chunk,
1268 ));
1269 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1271 // of a new Assistant response.
1272 //
1273 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1274 // will result in duplicating the text of the chunk in the rendered Markdown.
1275 thread.insert_message(
1276 Role::Assistant,
1277 vec![MessageSegment::Thinking {
1278 text: chunk.to_string(),
1279 signature,
1280 }],
1281 cx,
1282 );
1283 };
1284 }
1285 }
1286 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1287 let last_assistant_message_id = thread
1288 .messages
1289 .iter_mut()
1290 .rfind(|message| message.role == Role::Assistant)
1291 .map(|message| message.id)
1292 .unwrap_or_else(|| {
1293 thread.insert_message(Role::Assistant, vec![], cx)
1294 });
1295
1296 thread.tool_use.request_tool_use(
1297 last_assistant_message_id,
1298 tool_use,
1299 tool_use_metadata.clone(),
1300 cx,
1301 );
1302 }
1303 }
1304
1305 thread.touch_updated_at();
1306 cx.emit(ThreadEvent::StreamedCompletion);
1307 cx.notify();
1308
1309 thread.auto_capture_telemetry(cx);
1310 })?;
1311
1312 smol::future::yield_now().await;
1313 }
1314
1315 thread.update(cx, |thread, cx| {
1316 thread
1317 .pending_completions
1318 .retain(|completion| completion.id != pending_completion_id);
1319
1320 // If there is a response without tool use, summarize the message. Otherwise,
1321 // allow two tool uses before summarizing.
1322 if thread.summary.is_none()
1323 && thread.messages.len() >= 2
1324 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1325 {
1326 thread.summarize(cx);
1327 }
1328 })?;
1329
1330 anyhow::Ok(stop_reason)
1331 };
1332
1333 let result = stream_completion.await;
1334
1335 thread
1336 .update(cx, |thread, cx| {
1337 thread.finalize_pending_checkpoint(cx);
1338 match result.as_ref() {
1339 Ok(stop_reason) => match stop_reason {
1340 StopReason::ToolUse => {
1341 let tool_uses = thread.use_pending_tools(cx);
1342 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1343 }
1344 StopReason::EndTurn => {}
1345 StopReason::MaxTokens => {}
1346 },
1347 Err(error) => {
1348 if error.is::<PaymentRequiredError>() {
1349 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1350 } else if error.is::<MaxMonthlySpendReachedError>() {
1351 cx.emit(ThreadEvent::ShowError(
1352 ThreadError::MaxMonthlySpendReached,
1353 ));
1354 } else if let Some(error) =
1355 error.downcast_ref::<ModelRequestLimitReachedError>()
1356 {
1357 cx.emit(ThreadEvent::ShowError(
1358 ThreadError::ModelRequestLimitReached { plan: error.plan },
1359 ));
1360 } else if let Some(known_error) =
1361 error.downcast_ref::<LanguageModelKnownError>()
1362 {
1363 match known_error {
1364 LanguageModelKnownError::ContextWindowLimitExceeded {
1365 tokens,
1366 } => {
1367 thread.exceeded_window_error = Some(ExceededWindowError {
1368 model_id: model.id(),
1369 token_count: *tokens,
1370 });
1371 cx.notify();
1372 }
1373 }
1374 } else {
1375 let error_message = error
1376 .chain()
1377 .map(|err| err.to_string())
1378 .collect::<Vec<_>>()
1379 .join("\n");
1380 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1381 header: "Error interacting with language model".into(),
1382 message: SharedString::from(error_message.clone()),
1383 }));
1384 }
1385
1386 thread.cancel_last_completion(cx);
1387 }
1388 }
1389 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1390
1391 if let Some((request_callback, (request, response_events))) = thread
1392 .request_callback
1393 .as_mut()
1394 .zip(request_callback_parameters.as_ref())
1395 {
1396 request_callback(request, response_events);
1397 }
1398
1399 thread.auto_capture_telemetry(cx);
1400
1401 if let Ok(initial_usage) = initial_token_usage {
1402 let usage = thread.cumulative_token_usage - initial_usage;
1403
1404 telemetry::event!(
1405 "Assistant Thread Completion",
1406 thread_id = thread.id().to_string(),
1407 prompt_id = prompt_id,
1408 model = model.telemetry_id(),
1409 model_provider = model.provider_id().to_string(),
1410 input_tokens = usage.input_tokens,
1411 output_tokens = usage.output_tokens,
1412 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1413 cache_read_input_tokens = usage.cache_read_input_tokens,
1414 );
1415 }
1416 })
1417 .ok();
1418 });
1419
1420 self.pending_completions.push(PendingCompletion {
1421 id: pending_completion_id,
1422 _task: task,
1423 });
1424 }
1425
1426 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1427 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1428 return;
1429 };
1430
1431 if !model.provider.is_authenticated(cx) {
1432 return;
1433 }
1434
1435 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1436 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1437 If the conversation is about a specific subject, include it in the title. \
1438 Be descriptive. DO NOT speak in the first person.";
1439
1440 let request = self.to_summarize_request(added_user_message.into());
1441
1442 self.pending_summary = cx.spawn(async move |this, cx| {
1443 async move {
1444 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1445 let (mut messages, usage) = stream.await?;
1446
1447 if let Some(usage) = usage {
1448 this.update(cx, |_thread, cx| {
1449 cx.emit(ThreadEvent::UsageUpdated(usage));
1450 })
1451 .ok();
1452 }
1453
1454 let mut new_summary = String::new();
1455 while let Some(message) = messages.stream.next().await {
1456 let text = message?;
1457 let mut lines = text.lines();
1458 new_summary.extend(lines.next());
1459
1460 // Stop if the LLM generated multiple lines.
1461 if lines.next().is_some() {
1462 break;
1463 }
1464 }
1465
1466 this.update(cx, |this, cx| {
1467 if !new_summary.is_empty() {
1468 this.summary = Some(new_summary.into());
1469 }
1470
1471 cx.emit(ThreadEvent::SummaryGenerated);
1472 })?;
1473
1474 anyhow::Ok(())
1475 }
1476 .log_err()
1477 .await
1478 });
1479 }
1480
1481 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1482 let last_message_id = self.messages.last().map(|message| message.id)?;
1483
1484 match &self.detailed_summary_state {
1485 DetailedSummaryState::Generating { message_id, .. }
1486 | DetailedSummaryState::Generated { message_id, .. }
1487 if *message_id == last_message_id =>
1488 {
1489 // Already up-to-date
1490 return None;
1491 }
1492 _ => {}
1493 }
1494
1495 let ConfiguredModel { model, provider } =
1496 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1497
1498 if !provider.is_authenticated(cx) {
1499 return None;
1500 }
1501
1502 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1503 1. A brief overview of what was discussed\n\
1504 2. Key facts or information discovered\n\
1505 3. Outcomes or conclusions reached\n\
1506 4. Any action items or next steps if any\n\
1507 Format it in Markdown with headings and bullet points.";
1508
1509 let request = self.to_summarize_request(added_user_message.into());
1510
1511 let task = cx.spawn(async move |thread, cx| {
1512 let stream = model.stream_completion_text(request, &cx);
1513 let Some(mut messages) = stream.await.log_err() else {
1514 thread
1515 .update(cx, |this, _cx| {
1516 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1517 })
1518 .log_err();
1519
1520 return;
1521 };
1522
1523 let mut new_detailed_summary = String::new();
1524
1525 while let Some(chunk) = messages.stream.next().await {
1526 if let Some(chunk) = chunk.log_err() {
1527 new_detailed_summary.push_str(&chunk);
1528 }
1529 }
1530
1531 thread
1532 .update(cx, |this, _cx| {
1533 this.detailed_summary_state = DetailedSummaryState::Generated {
1534 text: new_detailed_summary.into(),
1535 message_id: last_message_id,
1536 };
1537 })
1538 .log_err();
1539 });
1540
1541 self.detailed_summary_state = DetailedSummaryState::Generating {
1542 message_id: last_message_id,
1543 };
1544
1545 Some(task)
1546 }
1547
1548 pub fn is_generating_detailed_summary(&self) -> bool {
1549 matches!(
1550 self.detailed_summary_state,
1551 DetailedSummaryState::Generating { .. }
1552 )
1553 }
1554
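    /// Processes all idle pending tool uses: tools that need confirmation (and are not always
    /// allowed via `always_allow_tool_actions`) emit a confirmation request, while the rest are
    /// run immediately. Returns the idle tool uses that were processed.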
1555 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1556 self.auto_capture_telemetry(cx);
1557 let request = self.to_completion_request(cx);
1558 let messages = Arc::new(request.messages);
1559 let pending_tool_uses = self
1560 .tool_use
1561 .pending_tool_uses()
1562 .into_iter()
1563 .filter(|tool_use| tool_use.status.is_idle())
1564 .cloned()
1565 .collect::<Vec<_>>();
1566
1567 for tool_use in pending_tool_uses.iter() {
1568 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1569 if tool.needs_confirmation(&tool_use.input, cx)
1570 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1571 {
1572 self.tool_use.confirm_tool_use(
1573 tool_use.id.clone(),
1574 tool_use.ui_text.clone(),
1575 tool_use.input.clone(),
1576 messages.clone(),
1577 tool,
1578 );
1579 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1580 } else {
1581 self.run_tool(
1582 tool_use.id.clone(),
1583 tool_use.ui_text.clone(),
1584 tool_use.input.clone(),
1585 &messages,
1586 tool,
1587 cx,
1588 );
1589 }
1590 }
1591 }
1592
1593 pending_tool_uses
1594 }
1595
1596 pub fn run_tool(
1597 &mut self,
1598 tool_use_id: LanguageModelToolUseId,
1599 ui_text: impl Into<SharedString>,
1600 input: serde_json::Value,
1601 messages: &[LanguageModelRequestMessage],
1602 tool: Arc<dyn Tool>,
1603 cx: &mut Context<Thread>,
1604 ) {
1605 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1606 self.tool_use
1607 .run_pending_tool(tool_use_id, ui_text.into(), task);
1608 }
1609
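    /// Runs `tool` with `input` (or produces an error result if the tool is disabled), stores any
    /// result card, and returns a task that records the tool's output on the thread when it
    /// completes.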
1610 fn spawn_tool_use(
1611 &mut self,
1612 tool_use_id: LanguageModelToolUseId,
1613 messages: &[LanguageModelRequestMessage],
1614 input: serde_json::Value,
1615 tool: Arc<dyn Tool>,
1616 cx: &mut Context<Thread>,
1617 ) -> Task<()> {
1618 let tool_name: Arc<str> = tool.name().into();
1619
1620 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1621 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1622 } else {
1623 tool.run(
1624 input,
1625 messages,
1626 self.project.clone(),
1627 self.action_log.clone(),
1628 cx,
1629 )
1630 };
1631
1632 // Store the card separately if it exists
1633 if let Some(card) = tool_result.card.clone() {
1634 self.tool_use
1635 .insert_tool_result_card(tool_use_id.clone(), card);
1636 }
1637
1638 cx.spawn({
1639 async move |thread: WeakEntity<Thread>, cx| {
1640 let output = tool_result.output.await;
1641
1642 thread
1643 .update(cx, |thread, cx| {
1644 let pending_tool_use = thread.tool_use.insert_tool_output(
1645 tool_use_id.clone(),
1646 tool_name,
1647 output,
1648 cx,
1649 );
1650 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1651 })
1652 .ok();
1653 }
1654 })
1655 }
1656
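    /// Records that a single tool use finished. Once every pending tool has finished and the run
    /// was not canceled, the tool results are attached to a new user message and sent back to the
    /// configured default model; a `ToolFinished` event is emitted either way.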
1657 fn tool_finished(
1658 &mut self,
1659 tool_use_id: LanguageModelToolUseId,
1660 pending_tool_use: Option<PendingToolUse>,
1661 canceled: bool,
1662 cx: &mut Context<Self>,
1663 ) {
1664 if self.all_tools_finished() {
1665 let model_registry = LanguageModelRegistry::read_global(cx);
1666 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1667 self.attach_tool_results(cx);
1668 if !canceled {
1669 self.send_to_model(model, cx);
1670 }
1671 }
1672 }
1673
1674 cx.emit(ThreadEvent::ToolFinished {
1675 tool_use_id,
1676 pending_tool_use,
1677 });
1678 }
1679
    /// Inserts an empty user message to be populated with tool results when the next request is sent.
1681 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1682 // Tool results are assumed to be waiting on the next message id, so they will populate
1683 // this empty message before sending to model. Would prefer this to be more straightforward.
1684 self.insert_message(Role::User, vec![], cx);
1685 self.auto_capture_telemetry(cx);
1686 }
1687
1688 /// Cancels the last pending completion, if there are any pending.
1689 ///
1690 /// Returns whether a completion was canceled.
1691 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1692 let canceled = if self.pending_completions.pop().is_some() {
1693 true
1694 } else {
1695 let mut canceled = false;
1696 for pending_tool_use in self.tool_use.cancel_pending() {
1697 canceled = true;
1698 self.tool_finished(
1699 pending_tool_use.id.clone(),
1700 Some(pending_tool_use),
1701 true,
1702 cx,
1703 );
1704 }
1705 canceled
1706 };
1707 self.finalize_pending_checkpoint(cx);
1708 canceled
1709 }
1710
1711 pub fn feedback(&self) -> Option<ThreadFeedback> {
1712 self.feedback
1713 }
1714
1715 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1716 self.message_feedback.get(&message_id).copied()
1717 }
1718
1719 pub fn report_message_feedback(
1720 &mut self,
1721 message_id: MessageId,
1722 feedback: ThreadFeedback,
1723 cx: &mut Context<Self>,
1724 ) -> Task<Result<()>> {
1725 if self.message_feedback.get(&message_id) == Some(&feedback) {
1726 return Task::ready(Ok(()));
1727 }
1728
1729 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1730 let serialized_thread = self.serialize(cx);
1731 let thread_id = self.id().clone();
1732 let client = self.project.read(cx).client();
1733
1734 let enabled_tool_names: Vec<String> = self
1735 .tools()
1736 .read(cx)
1737 .enabled_tools(cx)
1738 .iter()
1739 .map(|tool| tool.name().to_string())
1740 .collect();
1741
1742 self.message_feedback.insert(message_id, feedback);
1743
1744 cx.notify();
1745
1746 let message_content = self
1747 .message(message_id)
1748 .map(|msg| msg.to_string())
1749 .unwrap_or_default();
1750
1751 cx.background_spawn(async move {
1752 let final_project_snapshot = final_project_snapshot.await;
1753 let serialized_thread = serialized_thread.await?;
1754 let thread_data =
1755 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1756
1757 let rating = match feedback {
1758 ThreadFeedback::Positive => "positive",
1759 ThreadFeedback::Negative => "negative",
1760 };
1761 telemetry::event!(
1762 "Assistant Thread Rated",
1763 rating,
1764 thread_id,
1765 enabled_tool_names,
1766 message_id = message_id.0,
1767 message_content,
1768 thread_data,
1769 final_project_snapshot
1770 );
1771 client.telemetry().flush_events();
1772
1773 Ok(())
1774 })
1775 }
1776
1777 pub fn report_feedback(
1778 &mut self,
1779 feedback: ThreadFeedback,
1780 cx: &mut Context<Self>,
1781 ) -> Task<Result<()>> {
1782 let last_assistant_message_id = self
1783 .messages
1784 .iter()
1785 .rev()
1786 .find(|msg| msg.role == Role::Assistant)
1787 .map(|msg| msg.id);
1788
1789 if let Some(message_id) = last_assistant_message_id {
1790 self.report_message_feedback(message_id, feedback, cx)
1791 } else {
1792 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1793 let serialized_thread = self.serialize(cx);
1794 let thread_id = self.id().clone();
1795 let client = self.project.read(cx).client();
1796 self.feedback = Some(feedback);
1797 cx.notify();
1798
1799 cx.background_spawn(async move {
1800 let final_project_snapshot = final_project_snapshot.await;
1801 let serialized_thread = serialized_thread.await?;
1802 let thread_data = serde_json::to_value(serialized_thread)
1803 .unwrap_or_else(|_| serde_json::Value::Null);
1804
1805 let rating = match feedback {
1806 ThreadFeedback::Positive => "positive",
1807 ThreadFeedback::Negative => "negative",
1808 };
1809 telemetry::event!(
1810 "Assistant Thread Rated",
1811 rating,
1812 thread_id,
1813 thread_data,
1814 final_project_snapshot
1815 );
1816 client.telemetry().flush_events();
1817
1818 Ok(())
1819 })
1820 }
1821 }
1822
1823 /// Create a snapshot of the current project state including git information and unsaved buffers.
1824 fn project_snapshot(
1825 project: Entity<Project>,
1826 cx: &mut Context<Self>,
1827 ) -> Task<Arc<ProjectSnapshot>> {
1828 let git_store = project.read(cx).git_store().clone();
1829 let worktree_snapshots: Vec<_> = project
1830 .read(cx)
1831 .visible_worktrees(cx)
1832 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1833 .collect();
1834
1835 cx.spawn(async move |_, cx| {
1836 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1837
1838 let mut unsaved_buffers = Vec::new();
1839 cx.update(|app_cx| {
1840 let buffer_store = project.read(app_cx).buffer_store();
1841 for buffer_handle in buffer_store.read(app_cx).buffers() {
1842 let buffer = buffer_handle.read(app_cx);
1843 if buffer.is_dirty() {
1844 if let Some(file) = buffer.file() {
1845 let path = file.path().to_string_lossy().to_string();
1846 unsaved_buffers.push(path);
1847 }
1848 }
1849 }
1850 })
1851 .ok();
1852
1853 Arc::new(ProjectSnapshot {
1854 worktree_snapshots,
1855 unsaved_buffer_paths: unsaved_buffers,
1856 timestamp: Utc::now(),
1857 })
1858 })
1859 }
1860
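    /// Captures the worktree's absolute path and, when a local repository is found for it, the
    /// repository's remote URL, HEAD SHA, current branch, and a diff against the working tree.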
1861 fn worktree_snapshot(
1862 worktree: Entity<project::Worktree>,
1863 git_store: Entity<GitStore>,
1864 cx: &App,
1865 ) -> Task<WorktreeSnapshot> {
1866 cx.spawn(async move |cx| {
1867 // Get worktree path and snapshot
1868 let worktree_info = cx.update(|app_cx| {
1869 let worktree = worktree.read(app_cx);
1870 let path = worktree.abs_path().to_string_lossy().to_string();
1871 let snapshot = worktree.snapshot();
1872 (path, snapshot)
1873 });
1874
1875 let Ok((worktree_path, _snapshot)) = worktree_info else {
1876 return WorktreeSnapshot {
1877 worktree_path: String::new(),
1878 git_state: None,
1879 };
1880 };
1881
1882 let git_state = git_store
1883 .update(cx, |git_store, cx| {
1884 git_store
1885 .repositories()
1886 .values()
1887 .find(|repo| {
1888 repo.read(cx)
1889 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1890 .is_some()
1891 })
1892 .cloned()
1893 })
1894 .ok()
1895 .flatten()
1896 .map(|repo| {
1897 repo.update(cx, |repo, _| {
1898 let current_branch =
1899 repo.branch.as_ref().map(|branch| branch.name.to_string());
1900 repo.send_job(None, |state, _| async move {
1901 let RepositoryState::Local { backend, .. } = state else {
1902 return GitState {
1903 remote_url: None,
1904 head_sha: None,
1905 current_branch,
1906 diff: None,
1907 };
1908 };
1909
1910 let remote_url = backend.remote_url("origin");
1911 let head_sha = backend.head_sha();
1912 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1913
1914 GitState {
1915 remote_url,
1916 head_sha,
1917 current_branch,
1918 diff,
1919 }
1920 })
1921 })
1922 });
1923
1924 let git_state = match git_state {
1925 Some(git_state) => match git_state.ok() {
1926 Some(git_state) => git_state.await.ok(),
1927 None => None,
1928 },
1929 None => None,
1930 };
1931
1932 WorktreeSnapshot {
1933 worktree_path,
1934 git_state,
1935 }
1936 })
1937 }
1938
1939 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1940 let mut markdown = Vec::new();
1941
1942 if let Some(summary) = self.summary() {
1943 writeln!(markdown, "# {summary}\n")?;
1944 };
1945
1946 for message in self.messages() {
1947 writeln!(
1948 markdown,
1949 "## {role}\n",
1950 role = match message.role {
1951 Role::User => "User",
1952 Role::Assistant => "Assistant",
1953 Role::System => "System",
1954 }
1955 )?;
1956
1957 if !message.context.is_empty() {
1958 writeln!(markdown, "{}", message.context)?;
1959 }
1960
1961 for segment in &message.segments {
1962 match segment {
1963 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1964 MessageSegment::Thinking { text, .. } => {
1965 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
1966 }
1967 MessageSegment::RedactedThinking(_) => {}
1968 }
1969 }
1970
1971 for tool_use in self.tool_uses_for_message(message.id, cx) {
1972 writeln!(
1973 markdown,
1974 "**Use Tool: {} ({})**",
1975 tool_use.name, tool_use.id
1976 )?;
1977 writeln!(markdown, "```json")?;
1978 writeln!(
1979 markdown,
1980 "{}",
1981 serde_json::to_string_pretty(&tool_use.input)?
1982 )?;
1983 writeln!(markdown, "```")?;
1984 }
1985
1986 for tool_result in self.tool_results_for_message(message.id) {
1987 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1988 if tool_result.is_error {
1989 write!(markdown, " (Error)")?;
1990 }
1991
1992 writeln!(markdown, "**\n")?;
1993 writeln!(markdown, "{}", tool_result.content)?;
1994 }
1995 }
1996
1997 Ok(String::from_utf8_lossy(&markdown).to_string())
1998 }
1999
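    /// Marks the agent's edits within `buffer_range` of `buffer` as kept.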
2000 pub fn keep_edits_in_range(
2001 &mut self,
2002 buffer: Entity<language::Buffer>,
2003 buffer_range: Range<language::Anchor>,
2004 cx: &mut Context<Self>,
2005 ) {
2006 self.action_log.update(cx, |action_log, cx| {
2007 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2008 });
2009 }
2010
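    /// Marks all of the agent's outstanding edits as kept.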
2011 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2012 self.action_log
2013 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2014 }
2015
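    /// Rejects the agent's edits within the given ranges of `buffer`.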
2016 pub fn reject_edits_in_ranges(
2017 &mut self,
2018 buffer: Entity<language::Buffer>,
2019 buffer_ranges: Vec<Range<language::Anchor>>,
2020 cx: &mut Context<Self>,
2021 ) -> Task<Result<()>> {
2022 self.action_log.update(cx, |action_log, cx| {
2023 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2024 })
2025 }
2026
2027 pub fn action_log(&self) -> &Entity<ActionLog> {
2028 &self.action_log
2029 }
2030
2031 pub fn project(&self) -> &Entity<Project> {
2032 &self.project
2033 }
2034
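    /// Captures the serialized thread as a telemetry event when the `ThreadAutoCapture`
    /// feature flag is enabled, throttled to at most once every ten seconds.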
2035 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2036 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2037 return;
2038 }
2039
2040 let now = Instant::now();
2041 if let Some(last) = self.last_auto_capture_at {
2042 if now.duration_since(last).as_secs() < 10 {
2043 return;
2044 }
2045 }
2046
2047 self.last_auto_capture_at = Some(now);
2048
2049 let thread_id = self.id().clone();
2050 let github_login = self
2051 .project
2052 .read(cx)
2053 .user_store()
2054 .read(cx)
2055 .current_user()
2056 .map(|user| user.github_login.clone());
2057 let client = self.project.read(cx).client().clone();
2058 let serialize_task = self.serialize(cx);
2059
2060 cx.background_executor()
2061 .spawn(async move {
2062 if let Ok(serialized_thread) = serialize_task.await {
2063 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2064 telemetry::event!(
2065 "Agent Thread Auto-Captured",
2066 thread_id = thread_id.to_string(),
2067 thread_data = thread_data,
2068 auto_capture_reason = "tracked_user",
2069 github_login = github_login
2070 );
2071
2072 client.telemetry().flush_events();
2073 }
2074 }
2075 })
2076 .detach();
2077 }
2078
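    /// Returns the total token usage accumulated across every completion in this thread.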
2079 pub fn cumulative_token_usage(&self) -> TokenUsage {
2080 self.cumulative_token_usage
2081 }
2082
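    /// Returns the token usage recorded for the request preceding `message_id`, along with
    /// the default model's maximum token count.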
2083 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2084 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2085 return TotalTokenUsage::default();
2086 };
2087
2088 let max = model.model.max_token_count();
2089
2090 let index = self
2091 .messages
2092 .iter()
2093 .position(|msg| msg.id == message_id)
2094 .unwrap_or(0);
2095
2096 if index == 0 {
2097 return TotalTokenUsage { total: 0, max };
2098 }
2099
        let token_usage = self
            .request_token_usage
            .get(index - 1)
            .cloned()
            .unwrap_or_default();
2105
2106 TotalTokenUsage {
2107 total: token_usage.total_tokens() as usize,
2108 max,
2109 }
2110 }
2111
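    /// Returns the most recently recorded token usage against the default model's context
    /// window. If the current model previously exceeded its context window, the token count
    /// from that error is reported instead.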
2112 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2113 let model_registry = LanguageModelRegistry::read_global(cx);
2114 let Some(model) = model_registry.default_model() else {
2115 return TotalTokenUsage::default();
2116 };
2117
2118 let max = model.model.max_token_count();
2119
2120 if let Some(exceeded_error) = &self.exceeded_window_error {
2121 if model.model.id() == exceeded_error.model_id {
2122 return TotalTokenUsage {
2123 total: exceeded_error.token_count,
2124 max,
2125 };
2126 }
2127 }
2128
2129 let total = self
2130 .token_usage_at_last_message()
2131 .unwrap_or_default()
2132 .total_tokens() as usize;
2133
2134 TotalTokenUsage { total, max }
2135 }
2136
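    /// Returns the token usage recorded at the index of the last message, falling back to the
    /// most recent entry when the history is shorter than the message list.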
2137 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2138 self.request_token_usage
2139 .get(self.messages.len().saturating_sub(1))
2140 .or_else(|| self.request_token_usage.last())
2141 .cloned()
2142 }
2143
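    /// Resizes the per-message token usage history to the current number of messages and
    /// records `token_usage` for the last message.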
2144 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2145 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2146 self.request_token_usage
2147 .resize(self.messages.len(), placeholder);
2148
2149 if let Some(last) = self.request_token_usage.last_mut() {
2150 *last = token_usage;
2151 }
2152 }
2153
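    /// Records the user's denial of a tool use as an error result and marks the tool as finished.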
2154 pub fn deny_tool_use(
2155 &mut self,
2156 tool_use_id: LanguageModelToolUseId,
2157 tool_name: Arc<str>,
2158 cx: &mut Context<Self>,
2159 ) {
2160 let err = Err(anyhow::anyhow!(
2161 "Permission to run tool action denied by user"
2162 ));
2163
2164 self.tool_use
2165 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2166 self.tool_finished(tool_use_id.clone(), None, true, cx);
2167 }
2168}
2169
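/// User-facing errors raised while running a thread, surfaced via [`ThreadEvent::ShowError`].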
2170#[derive(Debug, Clone, Error)]
2171pub enum ThreadError {
2172 #[error("Payment required")]
2173 PaymentRequired,
2174 #[error("Max monthly spend reached")]
2175 MaxMonthlySpendReached,
2176 #[error("Model request limit reached")]
2177 ModelRequestLimitReached { plan: Plan },
2178 #[error("Message {header}: {message}")]
2179 Message {
2180 header: SharedString,
2181 message: SharedString,
2182 },
2183}
2184
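/// Events emitted by a [`Thread`] as it streams completions, updates messages, and runs tools.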
2185#[derive(Debug, Clone)]
2186pub enum ThreadEvent {
2187 ShowError(ThreadError),
2188 UsageUpdated(RequestUsage),
2189 StreamedCompletion,
2190 StreamedAssistantText(MessageId, String),
2191 StreamedAssistantThinking(MessageId, String),
2192 Stopped(Result<StopReason, Arc<anyhow::Error>>),
2193 MessageAdded(MessageId),
2194 MessageEdited(MessageId),
2195 MessageDeleted(MessageId),
2196 SummaryGenerated,
2197 SummaryChanged,
2198 UsePendingTools {
2199 tool_uses: Vec<PendingToolUse>,
2200 },
2201 ToolFinished {
2202 #[allow(unused)]
2203 tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool call, if there was one.
2205 pending_tool_use: Option<PendingToolUse>,
2206 },
2207 CheckpointChanged,
2208 ToolConfirmationNeeded,
2209}
2210
2211impl EventEmitter<ThreadEvent> for Thread {}
2212
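/// An in-flight completion request. The task drives the stream and is canceled when dropped.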
2213struct PendingCompletion {
2214 id: usize,
2215 _task: Task<()>,
2216}
2217
2218#[cfg(test)]
2219mod tests {
2220 use super::*;
2221 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2222 use assistant_settings::AssistantSettings;
2223 use context_server::ContextServerSettings;
2224 use editor::EditorSettings;
2225 use gpui::TestAppContext;
2226 use project::{FakeFs, Project};
2227 use prompt_store::PromptBuilder;
2228 use serde_json::json;
2229 use settings::{Settings, SettingsStore};
2230 use std::sync::Arc;
2231 use theme::ThemeSettings;
2232 use util::path;
2233 use workspace::Workspace;
2234
2235 #[gpui::test]
2236 async fn test_message_with_context(cx: &mut TestAppContext) {
2237 init_test_settings(cx);
2238
2239 let project = create_test_project(
2240 cx,
2241 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2242 )
2243 .await;
2244
2245 let (_workspace, _thread_store, thread, context_store) =
2246 setup_test_environment(cx, project.clone()).await;
2247
2248 add_file_to_context(&project, &context_store, "test/code.rs", cx)
2249 .await
2250 .unwrap();
2251
2252 let context =
2253 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2254
2255 // Insert user message with context
2256 let message_id = thread.update(cx, |thread, cx| {
2257 thread.insert_user_message("Please explain this code", vec![context], None, cx)
2258 });
2259
2260 // Check content and context in message object
2261 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2262
        // The expected context embeds the platform-specific path separator
2264 #[cfg(windows)]
2265 let path_part = r"test\code.rs";
2266 #[cfg(not(windows))]
2267 let path_part = "test/code.rs";
2268
2269 let expected_context = format!(
2270 r#"
2271<context>
2272The following items were attached by the user. You don't need to use other tools to read them.
2273
2274<files>
2275```rs {path_part}
2276fn main() {{
2277 println!("Hello, world!");
2278}}
2279```
2280</files>
2281</context>
2282"#
2283 );
2284
2285 assert_eq!(message.role, Role::User);
2286 assert_eq!(message.segments.len(), 1);
2287 assert_eq!(
2288 message.segments[0],
2289 MessageSegment::Text("Please explain this code".to_string())
2290 );
2291 assert_eq!(message.context, expected_context);
2292
2293 // Check message in request
2294 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2295
2296 assert_eq!(request.messages.len(), 2);
2297 let expected_full_message = format!("{}Please explain this code", expected_context);
2298 assert_eq!(request.messages[1].string_contents(), expected_full_message);
2299 }
2300
2301 #[gpui::test]
2302 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2303 init_test_settings(cx);
2304
2305 let project = create_test_project(
2306 cx,
2307 json!({
2308 "file1.rs": "fn function1() {}\n",
2309 "file2.rs": "fn function2() {}\n",
2310 "file3.rs": "fn function3() {}\n",
2311 }),
2312 )
2313 .await;
2314
2315 let (_, _thread_store, thread, context_store) =
2316 setup_test_environment(cx, project.clone()).await;
2317
2318 // Open files individually
2319 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2320 .await
2321 .unwrap();
2322 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2323 .await
2324 .unwrap();
2325 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2326 .await
2327 .unwrap();
2328
2329 // Get the context objects
2330 let contexts = context_store.update(cx, |store, _| store.context().clone());
2331 assert_eq!(contexts.len(), 3);
2332
2333 // First message with context 1
2334 let message1_id = thread.update(cx, |thread, cx| {
2335 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2336 });
2337
2338 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2339 let message2_id = thread.update(cx, |thread, cx| {
2340 thread.insert_user_message(
2341 "Message 2",
2342 vec![contexts[0].clone(), contexts[1].clone()],
2343 None,
2344 cx,
2345 )
2346 });
2347
2348 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2349 let message3_id = thread.update(cx, |thread, cx| {
2350 thread.insert_user_message(
2351 "Message 3",
2352 vec![
2353 contexts[0].clone(),
2354 contexts[1].clone(),
2355 contexts[2].clone(),
2356 ],
2357 None,
2358 cx,
2359 )
2360 });
2361
2362 // Check what contexts are included in each message
2363 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2364 (
2365 thread.message(message1_id).unwrap().clone(),
2366 thread.message(message2_id).unwrap().clone(),
2367 thread.message(message3_id).unwrap().clone(),
2368 )
2369 });
2370
2371 // First message should include context 1
2372 assert!(message1.context.contains("file1.rs"));
2373
2374 // Second message should include only context 2 (not 1)
2375 assert!(!message2.context.contains("file1.rs"));
2376 assert!(message2.context.contains("file2.rs"));
2377
2378 // Third message should include only context 3 (not 1 or 2)
2379 assert!(!message3.context.contains("file1.rs"));
2380 assert!(!message3.context.contains("file2.rs"));
2381 assert!(message3.context.contains("file3.rs"));
2382
2383 // Check entire request to make sure all contexts are properly included
2384 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2385
        // The request should contain the system message plus all 3 user messages
2387 assert_eq!(request.messages.len(), 4);
2388
2389 // Check that the contexts are properly formatted in each message
2390 assert!(request.messages[1].string_contents().contains("file1.rs"));
2391 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2392 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2393
2394 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2395 assert!(request.messages[2].string_contents().contains("file2.rs"));
2396 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2397
2398 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2399 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2400 assert!(request.messages[3].string_contents().contains("file3.rs"));
2401 }
2402
2403 #[gpui::test]
2404 async fn test_message_without_files(cx: &mut TestAppContext) {
2405 init_test_settings(cx);
2406
2407 let project = create_test_project(
2408 cx,
2409 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2410 )
2411 .await;
2412
2413 let (_, _thread_store, thread, _context_store) =
2414 setup_test_environment(cx, project.clone()).await;
2415
2416 // Insert user message without any context (empty context vector)
2417 let message_id = thread.update(cx, |thread, cx| {
2418 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2419 });
2420
2421 // Check content and context in message object
2422 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2423
2424 // Context should be empty when no files are included
2425 assert_eq!(message.role, Role::User);
2426 assert_eq!(message.segments.len(), 1);
2427 assert_eq!(
2428 message.segments[0],
2429 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2430 );
2431 assert_eq!(message.context, "");
2432
2433 // Check message in request
2434 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2435
2436 assert_eq!(request.messages.len(), 2);
2437 assert_eq!(
2438 request.messages[1].string_contents(),
2439 "What is the best way to learn Rust?"
2440 );
2441
2442 // Add second message, also without context
2443 let message2_id = thread.update(cx, |thread, cx| {
2444 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2445 });
2446
2447 let message2 =
2448 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2449 assert_eq!(message2.context, "");
2450
2451 // Check that both messages appear in the request
2452 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2453
2454 assert_eq!(request.messages.len(), 3);
2455 assert_eq!(
2456 request.messages[1].string_contents(),
2457 "What is the best way to learn Rust?"
2458 );
2459 assert_eq!(
2460 request.messages[2].string_contents(),
2461 "Are there any good books?"
2462 );
2463 }
2464
2465 #[gpui::test]
2466 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2467 init_test_settings(cx);
2468
2469 let project = create_test_project(
2470 cx,
2471 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2472 )
2473 .await;
2474
2475 let (_workspace, _thread_store, thread, context_store) =
2476 setup_test_environment(cx, project.clone()).await;
2477
2478 // Open buffer and add it to context
2479 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2480 .await
2481 .unwrap();
2482
2483 let context =
2484 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2485
2486 // Insert user message with the buffer as context
2487 thread.update(cx, |thread, cx| {
2488 thread.insert_user_message("Explain this code", vec![context], None, cx)
2489 });
2490
2491 // Create a request and check that it doesn't have a stale buffer warning yet
2492 let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2493
2494 // Make sure we don't have a stale file warning yet
2495 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2496 msg.string_contents()
2497 .contains("These files changed since last read:")
2498 });
2499 assert!(
2500 !has_stale_warning,
2501 "Should not have stale buffer warning before buffer is modified"
2502 );
2503
2504 // Modify the buffer
2505 buffer.update(cx, |buffer, cx| {
            // Insert text at offset 1 so the buffer differs from what was last read
2507 buffer.edit(
2508 [(1..1, "\n println!(\"Added a new line\");\n")],
2509 None,
2510 cx,
2511 );
2512 });
2513
2514 // Insert another user message without context
2515 thread.update(cx, |thread, cx| {
2516 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2517 });
2518
2519 // Create a new request and check for the stale buffer warning
2520 let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2521
2522 // We should have a stale file warning as the last message
2523 let last_message = new_request
2524 .messages
2525 .last()
2526 .expect("Request should have messages");
2527
2528 // The last message should be the stale buffer notification
2529 assert_eq!(last_message.role, Role::User);
2530
2531 // Check the exact content of the message
2532 let expected_content = "These files changed since last read:\n- code.rs\n";
2533 assert_eq!(
2534 last_message.string_contents(),
2535 expected_content,
2536 "Last message should be exactly the stale buffer notification"
2537 );
2538 }
2539
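    // Registers the settings and globals the thread tests depend on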
2540 fn init_test_settings(cx: &mut TestAppContext) {
2541 cx.update(|cx| {
2542 let settings_store = SettingsStore::test(cx);
2543 cx.set_global(settings_store);
2544 language::init(cx);
2545 Project::init_settings(cx);
2546 AssistantSettings::register(cx);
2547 prompt_store::init(cx);
2548 thread_store::init(cx);
2549 workspace::init_settings(cx);
2550 ThemeSettings::register(cx);
2551 ContextServerSettings::register(cx);
2552 EditorSettings::register(cx);
2553 });
2554 }
2555
2556 // Helper to create a test project with test files
2557 async fn create_test_project(
2558 cx: &mut TestAppContext,
2559 files: serde_json::Value,
2560 ) -> Entity<Project> {
2561 let fs = FakeFs::new(cx.executor());
2562 fs.insert_tree(path!("/test"), files).await;
2563 Project::test(fs, [path!("/test").as_ref()], cx).await
2564 }
2565
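    // Helper to set up a workspace, thread store, thread, and context store for a test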
2566 async fn setup_test_environment(
2567 cx: &mut TestAppContext,
2568 project: Entity<Project>,
2569 ) -> (
2570 Entity<Workspace>,
2571 Entity<ThreadStore>,
2572 Entity<Thread>,
2573 Entity<ContextStore>,
2574 ) {
2575 let (workspace, cx) =
2576 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2577
2578 let thread_store = cx
2579 .update(|_, cx| {
2580 ThreadStore::load(
2581 project.clone(),
2582 cx.new(|_| ToolWorkingSet::default()),
2583 Arc::new(PromptBuilder::new(None).unwrap()),
2584 cx,
2585 )
2586 })
2587 .await
2588 .unwrap();
2589
2590 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2591 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2592
2593 (workspace, thread_store, thread, context_store)
2594 }
2595
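    // Helper to open a buffer for `path` and add it to the context store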
2596 async fn add_file_to_context(
2597 project: &Entity<Project>,
2598 context_store: &Entity<ContextStore>,
2599 path: &str,
2600 cx: &mut TestAppContext,
2601 ) -> Result<Entity<language::Buffer>> {
2602 let buffer_path = project
2603 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2604 .unwrap();
2605
2606 let buffer = project
2607 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2608 .await
2609 .unwrap();
2610
2611 context_store
2612 .update(cx, |store, cx| {
2613 store.add_file_from_buffer(buffer.clone(), cx)
2614 })
2615 .await?;
2616
2617 Ok(buffer)
2618 }
2619}