1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
42
43#[derive(
44 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
45)]
46pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
66/// The ID of the user prompt that initiated a request.
67///
68/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
70pub struct PromptId(Arc<str>);
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
84#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
85pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
93/// A message in a [`Thread`].
94#[derive(Debug, Clone)]
95pub struct Message {
96 pub id: MessageId,
97 pub role: Role,
98 pub segments: Vec<MessageSegment>,
99 pub context: String,
100}
101
102impl Message {
103 /// Returns whether the message contains any meaningful text that should be displayed
104 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
105 pub fn should_display_content(&self) -> bool {
106 self.segments.iter().all(|segment| segment.should_display())
107 }
108
109 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
110 if let Some(MessageSegment::Thinking {
111 text: segment,
112 signature: current_signature,
113 }) = self.segments.last_mut()
114 {
115 if let Some(signature) = signature {
116 *current_signature = Some(signature);
117 }
118 segment.push_str(text);
119 } else {
120 self.segments.push(MessageSegment::Thinking {
121 text: text.to_string(),
122 signature,
123 });
124 }
125 }
126
127 pub fn push_text(&mut self, text: &str) {
128 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
129 segment.push_str(text);
130 } else {
131 self.segments.push(MessageSegment::Text(text.to_string()));
132 }
133 }
134
135 pub fn to_string(&self) -> String {
136 let mut result = String::new();
137
138 if !self.context.is_empty() {
139 result.push_str(&self.context);
140 }
141
142 for segment in &self.segments {
143 match segment {
144 MessageSegment::Text(text) => result.push_str(text),
145 MessageSegment::Thinking { text, .. } => {
146 result.push_str("<think>\n");
147 result.push_str(text);
148 result.push_str("\n</think>");
149 }
150 MessageSegment::RedactedThinking(_) => {}
151 }
152 }
153
154 result
155 }
156}
157
158#[derive(Debug, Clone, PartialEq, Eq)]
159pub enum MessageSegment {
160 Text(String),
161 Thinking {
162 text: String,
163 signature: Option<String>,
164 },
165 RedactedThinking(Vec<u8>),
166}
167
168impl MessageSegment {
169 pub fn should_display(&self) -> bool {
170 // We add USING_TOOL_MARKER when making a request that includes tool uses
171 // without non-whitespace text around them, and this can cause the model
172 // to mimic the pattern, so we consider those segments not displayable.
173 match self {
174 Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
175 Self::Thinking { text, .. } => text.is_empty() || text.trim() == USING_TOOL_MARKER,
176 Self::RedactedThinking(_) => false,
177 }
178 }
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct ProjectSnapshot {
183 pub worktree_snapshots: Vec<WorktreeSnapshot>,
184 pub unsaved_buffer_paths: Vec<String>,
185 pub timestamp: DateTime<Utc>,
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct WorktreeSnapshot {
190 pub worktree_path: String,
191 pub git_state: Option<GitState>,
192}
193
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct GitState {
196 pub remote_url: Option<String>,
197 pub head_sha: Option<String>,
198 pub current_branch: Option<String>,
199 pub diff: Option<String>,
200}
201
202#[derive(Clone)]
203pub struct ThreadCheckpoint {
204 message_id: MessageId,
205 git_checkpoint: GitStoreCheckpoint,
206}
207
208#[derive(Copy, Clone, Debug, PartialEq, Eq)]
209pub enum ThreadFeedback {
210 Positive,
211 Negative,
212}
213
214pub enum LastRestoreCheckpoint {
215 Pending {
216 message_id: MessageId,
217 },
218 Error {
219 message_id: MessageId,
220 error: String,
221 },
222}
223
224impl LastRestoreCheckpoint {
225 pub fn message_id(&self) -> MessageId {
226 match self {
227 LastRestoreCheckpoint::Pending { message_id } => *message_id,
228 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
229 }
230 }
231}
232
233#[derive(Clone, Debug, Default, Serialize, Deserialize)]
234pub enum DetailedSummaryState {
235 #[default]
236 NotGenerated,
237 Generating {
238 message_id: MessageId,
239 },
240 Generated {
241 text: SharedString,
242 message_id: MessageId,
243 },
244}
245
246#[derive(Default)]
247pub struct TotalTokenUsage {
248 pub total: usize,
249 pub max: usize,
250}
251
252impl TotalTokenUsage {
253 pub fn ratio(&self) -> TokenUsageRatio {
254 #[cfg(debug_assertions)]
255 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
256 .unwrap_or("0.8".to_string())
257 .parse()
258 .unwrap();
259 #[cfg(not(debug_assertions))]
260 let warning_threshold: f32 = 0.8;
261
262 if self.total >= self.max {
263 TokenUsageRatio::Exceeded
264 } else if self.total as f32 / self.max as f32 >= warning_threshold {
265 TokenUsageRatio::Warning
266 } else {
267 TokenUsageRatio::Normal
268 }
269 }
270
271 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
272 TotalTokenUsage {
273 total: self.total + tokens,
274 max: self.max,
275 }
276 }
277}
278
279#[derive(Debug, Default, PartialEq, Eq)]
280pub enum TokenUsageRatio {
281 #[default]
282 Normal,
283 Warning,
284 Exceeded,
285}
286
287/// A thread of conversation with the LLM.
288pub struct Thread {
289 id: ThreadId,
290 updated_at: DateTime<Utc>,
291 summary: Option<SharedString>,
292 pending_summary: Task<Option<()>>,
293 detailed_summary_state: DetailedSummaryState,
294 messages: Vec<Message>,
295 next_message_id: MessageId,
296 last_prompt_id: PromptId,
297 context: BTreeMap<ContextId, AssistantContext>,
298 context_by_message: HashMap<MessageId, Vec<ContextId>>,
299 project_context: SharedProjectContext,
300 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
301 completion_count: usize,
302 pending_completions: Vec<PendingCompletion>,
303 project: Entity<Project>,
304 prompt_builder: Arc<PromptBuilder>,
305 tools: Entity<ToolWorkingSet>,
306 tool_use: ToolUseState,
307 action_log: Entity<ActionLog>,
308 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
309 pending_checkpoint: Option<ThreadCheckpoint>,
310 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
311 request_token_usage: Vec<TokenUsage>,
312 cumulative_token_usage: TokenUsage,
313 exceeded_window_error: Option<ExceededWindowError>,
314 feedback: Option<ThreadFeedback>,
315 message_feedback: HashMap<MessageId, ThreadFeedback>,
316 last_auto_capture_at: Option<Instant>,
317 request_callback: Option<
318 Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
319 >,
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize)]
323pub struct ExceededWindowError {
324 /// Model used when last message exceeded context window
325 model_id: LanguageModelId,
326 /// Token count including last message
327 token_count: usize,
328}
329
330impl Thread {
331 pub fn new(
332 project: Entity<Project>,
333 tools: Entity<ToolWorkingSet>,
334 prompt_builder: Arc<PromptBuilder>,
335 system_prompt: SharedProjectContext,
336 cx: &mut Context<Self>,
337 ) -> Self {
338 Self {
339 id: ThreadId::new(),
340 updated_at: Utc::now(),
341 summary: None,
342 pending_summary: Task::ready(None),
343 detailed_summary_state: DetailedSummaryState::NotGenerated,
344 messages: Vec::new(),
345 next_message_id: MessageId(0),
346 last_prompt_id: PromptId::new(),
347 context: BTreeMap::default(),
348 context_by_message: HashMap::default(),
349 project_context: system_prompt,
350 checkpoints_by_message: HashMap::default(),
351 completion_count: 0,
352 pending_completions: Vec::new(),
353 project: project.clone(),
354 prompt_builder,
355 tools: tools.clone(),
356 last_restore_checkpoint: None,
357 pending_checkpoint: None,
358 tool_use: ToolUseState::new(tools.clone()),
359 action_log: cx.new(|_| ActionLog::new(project.clone())),
360 initial_project_snapshot: {
361 let project_snapshot = Self::project_snapshot(project, cx);
362 cx.foreground_executor()
363 .spawn(async move { Some(project_snapshot.await) })
364 .shared()
365 },
366 request_token_usage: Vec::new(),
367 cumulative_token_usage: TokenUsage::default(),
368 exceeded_window_error: None,
369 feedback: None,
370 message_feedback: HashMap::default(),
371 last_auto_capture_at: None,
372 request_callback: None,
373 }
374 }
375
376 pub fn deserialize(
377 id: ThreadId,
378 serialized: SerializedThread,
379 project: Entity<Project>,
380 tools: Entity<ToolWorkingSet>,
381 prompt_builder: Arc<PromptBuilder>,
382 project_context: SharedProjectContext,
383 cx: &mut Context<Self>,
384 ) -> Self {
385 let next_message_id = MessageId(
386 serialized
387 .messages
388 .last()
389 .map(|message| message.id.0 + 1)
390 .unwrap_or(0),
391 );
392 let tool_use =
393 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
394
395 Self {
396 id,
397 updated_at: serialized.updated_at,
398 summary: Some(serialized.summary),
399 pending_summary: Task::ready(None),
400 detailed_summary_state: serialized.detailed_summary_state,
401 messages: serialized
402 .messages
403 .into_iter()
404 .map(|message| Message {
405 id: message.id,
406 role: message.role,
407 segments: message
408 .segments
409 .into_iter()
410 .map(|segment| match segment {
411 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
412 SerializedMessageSegment::Thinking { text, signature } => {
413 MessageSegment::Thinking { text, signature }
414 }
415 SerializedMessageSegment::RedactedThinking { data } => {
416 MessageSegment::RedactedThinking(data)
417 }
418 })
419 .collect(),
420 context: message.context,
421 })
422 .collect(),
423 next_message_id,
424 last_prompt_id: PromptId::new(),
425 context: BTreeMap::default(),
426 context_by_message: HashMap::default(),
427 project_context,
428 checkpoints_by_message: HashMap::default(),
429 completion_count: 0,
430 pending_completions: Vec::new(),
431 last_restore_checkpoint: None,
432 pending_checkpoint: None,
433 project: project.clone(),
434 prompt_builder,
435 tools,
436 tool_use,
437 action_log: cx.new(|_| ActionLog::new(project)),
438 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
439 request_token_usage: serialized.request_token_usage,
440 cumulative_token_usage: serialized.cumulative_token_usage,
441 exceeded_window_error: None,
442 feedback: None,
443 message_feedback: HashMap::default(),
444 last_auto_capture_at: None,
445 request_callback: None,
446 }
447 }
448
449 pub fn set_request_callback(
450 &mut self,
451 callback: impl 'static
452 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
453 ) {
454 self.request_callback = Some(Box::new(callback));
455 }
456
457 pub fn id(&self) -> &ThreadId {
458 &self.id
459 }
460
461 pub fn is_empty(&self) -> bool {
462 self.messages.is_empty()
463 }
464
465 pub fn updated_at(&self) -> DateTime<Utc> {
466 self.updated_at
467 }
468
469 pub fn touch_updated_at(&mut self) {
470 self.updated_at = Utc::now();
471 }
472
473 pub fn advance_prompt_id(&mut self) {
474 self.last_prompt_id = PromptId::new();
475 }
476
477 pub fn summary(&self) -> Option<SharedString> {
478 self.summary.clone()
479 }
480
481 pub fn project_context(&self) -> SharedProjectContext {
482 self.project_context.clone()
483 }
484
485 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
486
487 pub fn summary_or_default(&self) -> SharedString {
488 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
489 }
490
491 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
492 let Some(current_summary) = &self.summary else {
493 // Don't allow setting summary until generated
494 return;
495 };
496
497 let mut new_summary = new_summary.into();
498
499 if new_summary.is_empty() {
500 new_summary = Self::DEFAULT_SUMMARY;
501 }
502
503 if current_summary != &new_summary {
504 self.summary = Some(new_summary);
505 cx.emit(ThreadEvent::SummaryChanged);
506 }
507 }
508
509 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
510 self.latest_detailed_summary()
511 .unwrap_or_else(|| self.text().into())
512 }
513
514 fn latest_detailed_summary(&self) -> Option<SharedString> {
515 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
516 Some(text.clone())
517 } else {
518 None
519 }
520 }
521
522 pub fn message(&self, id: MessageId) -> Option<&Message> {
523 self.messages.iter().find(|message| message.id == id)
524 }
525
526 pub fn messages(&self) -> impl Iterator<Item = &Message> {
527 self.messages.iter()
528 }
529
530 pub fn is_generating(&self) -> bool {
531 !self.pending_completions.is_empty() || !self.all_tools_finished()
532 }
533
534 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
535 &self.tools
536 }
537
538 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
539 self.tool_use
540 .pending_tool_uses()
541 .into_iter()
542 .find(|tool_use| &tool_use.id == id)
543 }
544
545 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
546 self.tool_use
547 .pending_tool_uses()
548 .into_iter()
549 .filter(|tool_use| tool_use.status.needs_confirmation())
550 }
551
552 pub fn has_pending_tool_uses(&self) -> bool {
553 !self.tool_use.pending_tool_uses().is_empty()
554 }
555
556 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
557 self.checkpoints_by_message.get(&id).cloned()
558 }
559
560 pub fn restore_checkpoint(
561 &mut self,
562 checkpoint: ThreadCheckpoint,
563 cx: &mut Context<Self>,
564 ) -> Task<Result<()>> {
565 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
566 message_id: checkpoint.message_id,
567 });
568 cx.emit(ThreadEvent::CheckpointChanged);
569 cx.notify();
570
571 let git_store = self.project().read(cx).git_store().clone();
572 let restore = git_store.update(cx, |git_store, cx| {
573 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
574 });
575
576 cx.spawn(async move |this, cx| {
577 let result = restore.await;
578 this.update(cx, |this, cx| {
579 if let Err(err) = result.as_ref() {
580 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
581 message_id: checkpoint.message_id,
582 error: err.to_string(),
583 });
584 } else {
585 this.truncate(checkpoint.message_id, cx);
586 this.last_restore_checkpoint = None;
587 }
588 this.pending_checkpoint = None;
589 cx.emit(ThreadEvent::CheckpointChanged);
590 cx.notify();
591 })?;
592 result
593 })
594 }
595
596 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
597 let pending_checkpoint = if self.is_generating() {
598 return;
599 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
600 checkpoint
601 } else {
602 return;
603 };
604
605 let git_store = self.project.read(cx).git_store().clone();
606 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
607 cx.spawn(async move |this, cx| match final_checkpoint.await {
608 Ok(final_checkpoint) => {
609 let equal = git_store
610 .update(cx, |store, cx| {
611 store.compare_checkpoints(
612 pending_checkpoint.git_checkpoint.clone(),
613 final_checkpoint.clone(),
614 cx,
615 )
616 })?
617 .await
618 .unwrap_or(false);
619
620 if equal {
621 git_store
622 .update(cx, |store, cx| {
623 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
624 })?
625 .detach();
626 } else {
627 this.update(cx, |this, cx| {
628 this.insert_checkpoint(pending_checkpoint, cx)
629 })?;
630 }
631
632 git_store
633 .update(cx, |store, cx| {
634 store.delete_checkpoint(final_checkpoint, cx)
635 })?
636 .detach();
637
638 Ok(())
639 }
640 Err(_) => this.update(cx, |this, cx| {
641 this.insert_checkpoint(pending_checkpoint, cx)
642 }),
643 })
644 .detach();
645 }
646
647 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
648 self.checkpoints_by_message
649 .insert(checkpoint.message_id, checkpoint);
650 cx.emit(ThreadEvent::CheckpointChanged);
651 cx.notify();
652 }
653
654 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
655 self.last_restore_checkpoint.as_ref()
656 }
657
658 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
659 let Some(message_ix) = self
660 .messages
661 .iter()
662 .rposition(|message| message.id == message_id)
663 else {
664 return;
665 };
666 for deleted_message in self.messages.drain(message_ix..) {
667 self.context_by_message.remove(&deleted_message.id);
668 self.checkpoints_by_message.remove(&deleted_message.id);
669 }
670 cx.notify();
671 }
672
673 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
674 self.context_by_message
675 .get(&id)
676 .into_iter()
677 .flat_map(|context| {
678 context
679 .iter()
680 .filter_map(|context_id| self.context.get(&context_id))
681 })
682 }
683
684 /// Returns whether all of the tool uses have finished running.
685 pub fn all_tools_finished(&self) -> bool {
686 // If the only pending tool uses left are the ones with errors, then
687 // that means that we've finished running all of the pending tools.
688 self.tool_use
689 .pending_tool_uses()
690 .iter()
691 .all(|tool_use| tool_use.status.is_error())
692 }
693
694 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
695 self.tool_use.tool_uses_for_message(id, cx)
696 }
697
698 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
699 self.tool_use.tool_results_for_message(id)
700 }
701
702 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
703 self.tool_use.tool_result(id)
704 }
705
706 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
707 Some(&self.tool_use.tool_result(id)?.content)
708 }
709
710 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
711 self.tool_use.tool_result_card(id).cloned()
712 }
713
714 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
715 self.tool_use.message_has_tool_results(message_id)
716 }
717
718 /// Filter out contexts that have already been included in previous messages
719 pub fn filter_new_context<'a>(
720 &self,
721 context: impl Iterator<Item = &'a AssistantContext>,
722 ) -> impl Iterator<Item = &'a AssistantContext> {
723 context.filter(|ctx| self.is_context_new(ctx))
724 }
725
726 fn is_context_new(&self, context: &AssistantContext) -> bool {
727 !self.context.contains_key(&context.id())
728 }
729
730 pub fn insert_user_message(
731 &mut self,
732 text: impl Into<String>,
733 context: Vec<AssistantContext>,
734 git_checkpoint: Option<GitStoreCheckpoint>,
735 cx: &mut Context<Self>,
736 ) -> MessageId {
737 let text = text.into();
738
739 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
740
741 let new_context: Vec<_> = context
742 .into_iter()
743 .filter(|ctx| self.is_context_new(ctx))
744 .collect();
745
746 if !new_context.is_empty() {
747 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
748 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
749 message.context = context_string;
750 }
751 }
752
753 self.action_log.update(cx, |log, cx| {
754 // Track all buffers added as context
755 for ctx in &new_context {
756 match ctx {
757 AssistantContext::File(file_ctx) => {
758 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
759 }
760 AssistantContext::Directory(dir_ctx) => {
761 for context_buffer in &dir_ctx.context_buffers {
762 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
763 }
764 }
765 AssistantContext::Symbol(symbol_ctx) => {
766 log.buffer_added_as_context(
767 symbol_ctx.context_symbol.buffer.clone(),
768 cx,
769 );
770 }
771 AssistantContext::Excerpt(excerpt_context) => {
772 log.buffer_added_as_context(
773 excerpt_context.context_buffer.buffer.clone(),
774 cx,
775 );
776 }
777 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
778 }
779 }
780 });
781 }
782
783 let context_ids = new_context
784 .iter()
785 .map(|context| context.id())
786 .collect::<Vec<_>>();
787 self.context.extend(
788 new_context
789 .into_iter()
790 .map(|context| (context.id(), context)),
791 );
792 self.context_by_message.insert(message_id, context_ids);
793
794 if let Some(git_checkpoint) = git_checkpoint {
795 self.pending_checkpoint = Some(ThreadCheckpoint {
796 message_id,
797 git_checkpoint,
798 });
799 }
800
801 self.auto_capture_telemetry(cx);
802
803 message_id
804 }
805
806 pub fn insert_message(
807 &mut self,
808 role: Role,
809 segments: Vec<MessageSegment>,
810 cx: &mut Context<Self>,
811 ) -> MessageId {
812 let id = self.next_message_id.post_inc();
813 self.messages.push(Message {
814 id,
815 role,
816 segments,
817 context: String::new(),
818 });
819 self.touch_updated_at();
820 cx.emit(ThreadEvent::MessageAdded(id));
821 id
822 }
823
824 pub fn edit_message(
825 &mut self,
826 id: MessageId,
827 new_role: Role,
828 new_segments: Vec<MessageSegment>,
829 cx: &mut Context<Self>,
830 ) -> bool {
831 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
832 return false;
833 };
834 message.role = new_role;
835 message.segments = new_segments;
836 self.touch_updated_at();
837 cx.emit(ThreadEvent::MessageEdited(id));
838 true
839 }
840
841 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
842 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
843 return false;
844 };
845 self.messages.remove(index);
846 self.context_by_message.remove(&id);
847 self.touch_updated_at();
848 cx.emit(ThreadEvent::MessageDeleted(id));
849 true
850 }
851
852 /// Returns the representation of this [`Thread`] in a textual form.
853 ///
854 /// This is the representation we use when attaching a thread as context to another thread.
855 pub fn text(&self) -> String {
856 let mut text = String::new();
857
858 for message in &self.messages {
859 text.push_str(match message.role {
860 language_model::Role::User => "User:",
861 language_model::Role::Assistant => "Assistant:",
862 language_model::Role::System => "System:",
863 });
864 text.push('\n');
865
866 for segment in &message.segments {
867 match segment {
868 MessageSegment::Text(content) => text.push_str(content),
869 MessageSegment::Thinking { text: content, .. } => {
870 text.push_str(&format!("<think>{}</think>", content))
871 }
872 MessageSegment::RedactedThinking(_) => {}
873 }
874 }
875 text.push('\n');
876 }
877
878 text
879 }
880
881 /// Serializes this thread into a format for storage or telemetry.
882 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
883 let initial_project_snapshot = self.initial_project_snapshot.clone();
884 cx.spawn(async move |this, cx| {
885 let initial_project_snapshot = initial_project_snapshot.await;
886 this.read_with(cx, |this, cx| SerializedThread {
887 version: SerializedThread::VERSION.to_string(),
888 summary: this.summary_or_default(),
889 updated_at: this.updated_at(),
890 messages: this
891 .messages()
892 .map(|message| SerializedMessage {
893 id: message.id,
894 role: message.role,
895 segments: message
896 .segments
897 .iter()
898 .map(|segment| match segment {
899 MessageSegment::Text(text) => {
900 SerializedMessageSegment::Text { text: text.clone() }
901 }
902 MessageSegment::Thinking { text, signature } => {
903 SerializedMessageSegment::Thinking {
904 text: text.clone(),
905 signature: signature.clone(),
906 }
907 }
908 MessageSegment::RedactedThinking(data) => {
909 SerializedMessageSegment::RedactedThinking {
910 data: data.clone(),
911 }
912 }
913 })
914 .collect(),
915 tool_uses: this
916 .tool_uses_for_message(message.id, cx)
917 .into_iter()
918 .map(|tool_use| SerializedToolUse {
919 id: tool_use.id,
920 name: tool_use.name,
921 input: tool_use.input,
922 })
923 .collect(),
924 tool_results: this
925 .tool_results_for_message(message.id)
926 .into_iter()
927 .map(|tool_result| SerializedToolResult {
928 tool_use_id: tool_result.tool_use_id.clone(),
929 is_error: tool_result.is_error,
930 content: tool_result.content.clone(),
931 })
932 .collect(),
933 context: message.context.clone(),
934 })
935 .collect(),
936 initial_project_snapshot,
937 cumulative_token_usage: this.cumulative_token_usage,
938 request_token_usage: this.request_token_usage.clone(),
939 detailed_summary_state: this.detailed_summary_state.clone(),
940 exceeded_window_error: this.exceeded_window_error.clone(),
941 })
942 })
943 }
944
945 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
946 let mut request = self.to_completion_request(cx);
947 if model.supports_tools() {
948 request.tools = {
949 let mut tools = Vec::new();
950 tools.extend(
951 self.tools()
952 .read(cx)
953 .enabled_tools(cx)
954 .into_iter()
955 .filter_map(|tool| {
956 // Skip tools that cannot be supported
957 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
958 Some(LanguageModelRequestTool {
959 name: tool.name(),
960 description: tool.description(),
961 input_schema,
962 })
963 }),
964 );
965
966 tools
967 };
968 }
969
970 self.stream_completion(request, model, cx);
971 }
972
973 pub fn used_tools_since_last_user_message(&self) -> bool {
974 for message in self.messages.iter().rev() {
975 if self.tool_use.message_has_tool_results(message.id) {
976 return true;
977 } else if message.role == Role::User {
978 return false;
979 }
980 }
981
982 false
983 }
984
985 pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
986 let mut request = LanguageModelRequest {
987 thread_id: Some(self.id.to_string()),
988 prompt_id: Some(self.last_prompt_id.to_string()),
989 messages: vec![],
990 tools: Vec::new(),
991 stop: Vec::new(),
992 temperature: None,
993 };
994
995 if let Some(project_context) = self.project_context.borrow().as_ref() {
996 match self
997 .prompt_builder
998 .generate_assistant_system_prompt(project_context)
999 {
1000 Err(err) => {
1001 let message = format!("{err:?}").into();
1002 log::error!("{message}");
1003 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1004 header: "Error generating system prompt".into(),
1005 message,
1006 }));
1007 }
1008 Ok(system_prompt) => {
1009 request.messages.push(LanguageModelRequestMessage {
1010 role: Role::System,
1011 content: vec![MessageContent::Text(system_prompt)],
1012 cache: true,
1013 });
1014 }
1015 }
1016 } else {
1017 let message = "Context for system prompt unexpectedly not ready.".into();
1018 log::error!("{message}");
1019 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1020 header: "Error generating system prompt".into(),
1021 message,
1022 }));
1023 }
1024
1025 for message in &self.messages {
1026 let mut request_message = LanguageModelRequestMessage {
1027 role: message.role,
1028 content: Vec::new(),
1029 cache: false,
1030 };
1031
1032 self.tool_use
1033 .attach_tool_results(message.id, &mut request_message);
1034
1035 if !message.context.is_empty() {
1036 request_message
1037 .content
1038 .push(MessageContent::Text(message.context.to_string()));
1039 }
1040
1041 for segment in &message.segments {
1042 match segment {
1043 MessageSegment::Text(text) => {
1044 if !text.is_empty() {
1045 request_message
1046 .content
1047 .push(MessageContent::Text(text.into()));
1048 }
1049 }
1050 MessageSegment::Thinking { text, signature } => {
1051 if !text.is_empty() {
1052 request_message.content.push(MessageContent::Thinking {
1053 text: text.into(),
1054 signature: signature.clone(),
1055 });
1056 }
1057 }
1058 MessageSegment::RedactedThinking(data) => {
1059 request_message
1060 .content
1061 .push(MessageContent::RedactedThinking(data.clone()));
1062 }
1063 };
1064 }
1065
1066 self.tool_use
1067 .attach_tool_uses(message.id, &mut request_message);
1068
1069 request.messages.push(request_message);
1070 }
1071
1072 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1073 if let Some(last) = request.messages.last_mut() {
1074 last.cache = true;
1075 }
1076
1077 self.attached_tracked_files_state(&mut request.messages, cx);
1078
1079 request
1080 }
1081
1082 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1083 let mut request = LanguageModelRequest {
1084 thread_id: None,
1085 prompt_id: None,
1086 messages: vec![],
1087 tools: Vec::new(),
1088 stop: Vec::new(),
1089 temperature: None,
1090 };
1091
1092 for message in &self.messages {
1093 let mut request_message = LanguageModelRequestMessage {
1094 role: message.role,
1095 content: Vec::new(),
1096 cache: false,
1097 };
1098
1099 // Skip tool results during summarization.
1100 if self.tool_use.message_has_tool_results(message.id) {
1101 continue;
1102 }
1103
1104 for segment in &message.segments {
1105 match segment {
1106 MessageSegment::Text(text) => request_message
1107 .content
1108 .push(MessageContent::Text(text.clone())),
1109 MessageSegment::Thinking { .. } => {}
1110 MessageSegment::RedactedThinking(_) => {}
1111 }
1112 }
1113
1114 if request_message.content.is_empty() {
1115 continue;
1116 }
1117
1118 request.messages.push(request_message);
1119 }
1120
1121 request.messages.push(LanguageModelRequestMessage {
1122 role: Role::User,
1123 content: vec![MessageContent::Text(added_user_message)],
1124 cache: false,
1125 });
1126
1127 request
1128 }
1129
1130 fn attached_tracked_files_state(
1131 &self,
1132 messages: &mut Vec<LanguageModelRequestMessage>,
1133 cx: &App,
1134 ) {
1135 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1136
1137 let mut stale_message = String::new();
1138
1139 let action_log = self.action_log.read(cx);
1140
1141 for stale_file in action_log.stale_buffers(cx) {
1142 let Some(file) = stale_file.read(cx).file() else {
1143 continue;
1144 };
1145
1146 if stale_message.is_empty() {
1147 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1148 }
1149
1150 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1151 }
1152
1153 let mut content = Vec::with_capacity(2);
1154
1155 if !stale_message.is_empty() {
1156 content.push(stale_message.into());
1157 }
1158
1159 if !content.is_empty() {
1160 let context_message = LanguageModelRequestMessage {
1161 role: Role::User,
1162 content,
1163 cache: false,
1164 };
1165
1166 messages.push(context_message);
1167 }
1168 }
1169
1170 pub fn stream_completion(
1171 &mut self,
1172 request: LanguageModelRequest,
1173 model: Arc<dyn LanguageModel>,
1174 cx: &mut Context<Self>,
1175 ) {
1176 let pending_completion_id = post_inc(&mut self.completion_count);
1177 let mut request_callback_parameters = if self.request_callback.is_some() {
1178 Some((request.clone(), Vec::new()))
1179 } else {
1180 None
1181 };
1182 let prompt_id = self.last_prompt_id.clone();
1183 let task = cx.spawn(async move |thread, cx| {
1184 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1185 let initial_token_usage =
1186 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1187 let stream_completion = async {
1188 let (mut events, usage) = stream_completion_future.await?;
1189
1190 let mut stop_reason = StopReason::EndTurn;
1191 let mut current_token_usage = TokenUsage::default();
1192
1193 if let Some(usage) = usage {
1194 thread
1195 .update(cx, |_thread, cx| {
1196 cx.emit(ThreadEvent::UsageUpdated(usage));
1197 })
1198 .ok();
1199 }
1200
1201 while let Some(event) = events.next().await {
1202 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1203 response_events
1204 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1205 }
1206
1207 let event = event?;
1208
1209 thread.update(cx, |thread, cx| {
1210 match event {
1211 LanguageModelCompletionEvent::StartMessage { .. } => {
1212 thread.insert_message(
1213 Role::Assistant,
1214 vec![MessageSegment::Text(String::new())],
1215 cx,
1216 );
1217 }
1218 LanguageModelCompletionEvent::Stop(reason) => {
1219 stop_reason = reason;
1220 }
1221 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1222 thread.update_token_usage_at_last_message(token_usage);
1223 thread.cumulative_token_usage = thread.cumulative_token_usage
1224 + token_usage
1225 - current_token_usage;
1226 current_token_usage = token_usage;
1227 }
1228 LanguageModelCompletionEvent::Text(chunk) => {
1229 if let Some(last_message) = thread.messages.last_mut() {
1230 if last_message.role == Role::Assistant {
1231 last_message.push_text(&chunk);
1232 cx.emit(ThreadEvent::StreamedAssistantText(
1233 last_message.id,
1234 chunk,
1235 ));
1236 } else {
1237 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1238 // of a new Assistant response.
1239 //
1240 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1241 // will result in duplicating the text of the chunk in the rendered Markdown.
1242 thread.insert_message(
1243 Role::Assistant,
1244 vec![MessageSegment::Text(chunk.to_string())],
1245 cx,
1246 );
1247 };
1248 }
1249 }
1250 LanguageModelCompletionEvent::Thinking {
1251 text: chunk,
1252 signature,
1253 } => {
1254 if let Some(last_message) = thread.messages.last_mut() {
1255 if last_message.role == Role::Assistant {
1256 last_message.push_thinking(&chunk, signature);
1257 cx.emit(ThreadEvent::StreamedAssistantThinking(
1258 last_message.id,
1259 chunk,
1260 ));
1261 } else {
1262 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1263 // of a new Assistant response.
1264 //
1265 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1266 // will result in duplicating the text of the chunk in the rendered Markdown.
1267 thread.insert_message(
1268 Role::Assistant,
1269 vec![MessageSegment::Thinking {
1270 text: chunk.to_string(),
1271 signature,
1272 }],
1273 cx,
1274 );
1275 };
1276 }
1277 }
1278 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1279 let last_assistant_message_id = thread
1280 .messages
1281 .iter_mut()
1282 .rfind(|message| message.role == Role::Assistant)
1283 .map(|message| message.id)
1284 .unwrap_or_else(|| {
1285 thread.insert_message(Role::Assistant, vec![], cx)
1286 });
1287
1288 thread.tool_use.request_tool_use(
1289 last_assistant_message_id,
1290 tool_use,
1291 cx,
1292 );
1293 }
1294 }
1295
1296 thread.touch_updated_at();
1297 cx.emit(ThreadEvent::StreamedCompletion);
1298 cx.notify();
1299
1300 thread.auto_capture_telemetry(cx);
1301 })?;
1302
1303 smol::future::yield_now().await;
1304 }
1305
1306 thread.update(cx, |thread, cx| {
1307 thread
1308 .pending_completions
1309 .retain(|completion| completion.id != pending_completion_id);
1310
1311 // If there is a response without tool use, summarize the message. Otherwise,
1312 // allow two tool uses before summarizing.
1313 if thread.summary.is_none()
1314 && thread.messages.len() >= 2
1315 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1316 {
1317 thread.summarize(cx);
1318 }
1319 })?;
1320
1321 anyhow::Ok(stop_reason)
1322 };
1323
1324 let result = stream_completion.await;
1325
1326 thread
1327 .update(cx, |thread, cx| {
1328 thread.finalize_pending_checkpoint(cx);
1329 match result.as_ref() {
1330 Ok(stop_reason) => match stop_reason {
1331 StopReason::ToolUse => {
1332 let tool_uses = thread.use_pending_tools(cx);
1333 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1334 }
1335 StopReason::EndTurn => {}
1336 StopReason::MaxTokens => {}
1337 },
1338 Err(error) => {
1339 if error.is::<PaymentRequiredError>() {
1340 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1341 } else if error.is::<MaxMonthlySpendReachedError>() {
1342 cx.emit(ThreadEvent::ShowError(
1343 ThreadError::MaxMonthlySpendReached,
1344 ));
1345 } else if let Some(error) =
1346 error.downcast_ref::<ModelRequestLimitReachedError>()
1347 {
1348 cx.emit(ThreadEvent::ShowError(
1349 ThreadError::ModelRequestLimitReached { plan: error.plan },
1350 ));
1351 } else if let Some(known_error) =
1352 error.downcast_ref::<LanguageModelKnownError>()
1353 {
1354 match known_error {
1355 LanguageModelKnownError::ContextWindowLimitExceeded {
1356 tokens,
1357 } => {
1358 thread.exceeded_window_error = Some(ExceededWindowError {
1359 model_id: model.id(),
1360 token_count: *tokens,
1361 });
1362 cx.notify();
1363 }
1364 }
1365 } else {
1366 let error_message = error
1367 .chain()
1368 .map(|err| err.to_string())
1369 .collect::<Vec<_>>()
1370 .join("\n");
1371 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1372 header: "Error interacting with language model".into(),
1373 message: SharedString::from(error_message.clone()),
1374 }));
1375 }
1376
1377 thread.cancel_last_completion(cx);
1378 }
1379 }
1380 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1381
1382 if let Some((request_callback, (request, response_events))) = thread
1383 .request_callback
1384 .as_mut()
1385 .zip(request_callback_parameters.as_ref())
1386 {
1387 request_callback(request, response_events);
1388 }
1389
1390 thread.auto_capture_telemetry(cx);
1391
1392 if let Ok(initial_usage) = initial_token_usage {
1393 let usage = thread.cumulative_token_usage - initial_usage;
1394
1395 telemetry::event!(
1396 "Assistant Thread Completion",
1397 thread_id = thread.id().to_string(),
1398 prompt_id = prompt_id,
1399 model = model.telemetry_id(),
1400 model_provider = model.provider_id().to_string(),
1401 input_tokens = usage.input_tokens,
1402 output_tokens = usage.output_tokens,
1403 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1404 cache_read_input_tokens = usage.cache_read_input_tokens,
1405 );
1406 }
1407 })
1408 .ok();
1409 });
1410
1411 self.pending_completions.push(PendingCompletion {
1412 id: pending_completion_id,
1413 _task: task,
1414 });
1415 }
1416
1417 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1418 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1419 return;
1420 };
1421
1422 if !model.provider.is_authenticated(cx) {
1423 return;
1424 }
1425
1426 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1427 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1428 If the conversation is about a specific subject, include it in the title. \
1429 Be descriptive. DO NOT speak in the first person.";
1430
1431 let request = self.to_summarize_request(added_user_message.into());
1432
1433 self.pending_summary = cx.spawn(async move |this, cx| {
1434 async move {
1435 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1436 let (mut messages, usage) = stream.await?;
1437
1438 if let Some(usage) = usage {
1439 this.update(cx, |_thread, cx| {
1440 cx.emit(ThreadEvent::UsageUpdated(usage));
1441 })
1442 .ok();
1443 }
1444
1445 let mut new_summary = String::new();
1446 while let Some(message) = messages.stream.next().await {
1447 let text = message?;
1448 let mut lines = text.lines();
1449 new_summary.extend(lines.next());
1450
1451 // Stop if the LLM generated multiple lines.
1452 if lines.next().is_some() {
1453 break;
1454 }
1455 }
1456
1457 this.update(cx, |this, cx| {
1458 if !new_summary.is_empty() {
1459 this.summary = Some(new_summary.into());
1460 }
1461
1462 cx.emit(ThreadEvent::SummaryGenerated);
1463 })?;
1464
1465 anyhow::Ok(())
1466 }
1467 .log_err()
1468 .await
1469 });
1470 }
1471
1472 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1473 let last_message_id = self.messages.last().map(|message| message.id)?;
1474
1475 match &self.detailed_summary_state {
1476 DetailedSummaryState::Generating { message_id, .. }
1477 | DetailedSummaryState::Generated { message_id, .. }
1478 if *message_id == last_message_id =>
1479 {
1480 // Already up-to-date
1481 return None;
1482 }
1483 _ => {}
1484 }
1485
1486 let ConfiguredModel { model, provider } =
1487 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1488
1489 if !provider.is_authenticated(cx) {
1490 return None;
1491 }
1492
1493 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1494 1. A brief overview of what was discussed\n\
1495 2. Key facts or information discovered\n\
1496 3. Outcomes or conclusions reached\n\
1497 4. Any action items or next steps if any\n\
1498 Format it in Markdown with headings and bullet points.";
1499
1500 let request = self.to_summarize_request(added_user_message.into());
1501
1502 let task = cx.spawn(async move |thread, cx| {
1503 let stream = model.stream_completion_text(request, &cx);
1504 let Some(mut messages) = stream.await.log_err() else {
1505 thread
1506 .update(cx, |this, _cx| {
1507 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1508 })
1509 .log_err();
1510
1511 return;
1512 };
1513
1514 let mut new_detailed_summary = String::new();
1515
1516 while let Some(chunk) = messages.stream.next().await {
1517 if let Some(chunk) = chunk.log_err() {
1518 new_detailed_summary.push_str(&chunk);
1519 }
1520 }
1521
1522 thread
1523 .update(cx, |this, _cx| {
1524 this.detailed_summary_state = DetailedSummaryState::Generated {
1525 text: new_detailed_summary.into(),
1526 message_id: last_message_id,
1527 };
1528 })
1529 .log_err();
1530 });
1531
1532 self.detailed_summary_state = DetailedSummaryState::Generating {
1533 message_id: last_message_id,
1534 };
1535
1536 Some(task)
1537 }
1538
1539 pub fn is_generating_detailed_summary(&self) -> bool {
1540 matches!(
1541 self.detailed_summary_state,
1542 DetailedSummaryState::Generating { .. }
1543 )
1544 }
1545
1546 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1547 self.auto_capture_telemetry(cx);
1548 let request = self.to_completion_request(cx);
1549 let messages = Arc::new(request.messages);
1550 let pending_tool_uses = self
1551 .tool_use
1552 .pending_tool_uses()
1553 .into_iter()
1554 .filter(|tool_use| tool_use.status.is_idle())
1555 .cloned()
1556 .collect::<Vec<_>>();
1557
1558 for tool_use in pending_tool_uses.iter() {
1559 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1560 if tool.needs_confirmation(&tool_use.input, cx)
1561 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1562 {
1563 self.tool_use.confirm_tool_use(
1564 tool_use.id.clone(),
1565 tool_use.ui_text.clone(),
1566 tool_use.input.clone(),
1567 messages.clone(),
1568 tool,
1569 );
1570 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1571 } else {
1572 self.run_tool(
1573 tool_use.id.clone(),
1574 tool_use.ui_text.clone(),
1575 tool_use.input.clone(),
1576 &messages,
1577 tool,
1578 cx,
1579 );
1580 }
1581 }
1582 }
1583
1584 pending_tool_uses
1585 }
1586
1587 pub fn run_tool(
1588 &mut self,
1589 tool_use_id: LanguageModelToolUseId,
1590 ui_text: impl Into<SharedString>,
1591 input: serde_json::Value,
1592 messages: &[LanguageModelRequestMessage],
1593 tool: Arc<dyn Tool>,
1594 cx: &mut Context<Thread>,
1595 ) {
1596 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1597 self.tool_use
1598 .run_pending_tool(tool_use_id, ui_text.into(), task);
1599 }
1600
1601 fn spawn_tool_use(
1602 &mut self,
1603 tool_use_id: LanguageModelToolUseId,
1604 messages: &[LanguageModelRequestMessage],
1605 input: serde_json::Value,
1606 tool: Arc<dyn Tool>,
1607 cx: &mut Context<Thread>,
1608 ) -> Task<()> {
1609 let tool_name: Arc<str> = tool.name().into();
1610
1611 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1612 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1613 } else {
1614 tool.run(
1615 input,
1616 messages,
1617 self.project.clone(),
1618 self.action_log.clone(),
1619 cx,
1620 )
1621 };
1622
1623 // Store the card separately if it exists
1624 if let Some(card) = tool_result.card.clone() {
1625 self.tool_use
1626 .insert_tool_result_card(tool_use_id.clone(), card);
1627 }
1628
1629 cx.spawn({
1630 async move |thread: WeakEntity<Thread>, cx| {
1631 let output = tool_result.output.await;
1632
1633 thread
1634 .update(cx, |thread, cx| {
1635 let pending_tool_use = thread.tool_use.insert_tool_output(
1636 tool_use_id.clone(),
1637 tool_name,
1638 output,
1639 cx,
1640 );
1641 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1642 })
1643 .ok();
1644 }
1645 })
1646 }
1647
1648 fn tool_finished(
1649 &mut self,
1650 tool_use_id: LanguageModelToolUseId,
1651 pending_tool_use: Option<PendingToolUse>,
1652 canceled: bool,
1653 cx: &mut Context<Self>,
1654 ) {
1655 if self.all_tools_finished() {
1656 let model_registry = LanguageModelRegistry::read_global(cx);
1657 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1658 self.attach_tool_results(cx);
1659 if !canceled {
1660 self.send_to_model(model, cx);
1661 }
1662 }
1663 }
1664
1665 cx.emit(ThreadEvent::ToolFinished {
1666 tool_use_id,
1667 pending_tool_use,
1668 });
1669 }
1670
1671 /// Insert an empty message to be populated with tool results upon send.
1672 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1673 // Tool results are assumed to be waiting on the next message id, so they will populate
1674 // this empty message before sending to model. Would prefer this to be more straightforward.
1675 self.insert_message(Role::User, vec![], cx);
1676 self.auto_capture_telemetry(cx);
1677 }
1678
1679 /// Cancels the last pending completion, if there are any pending.
1680 ///
1681 /// Returns whether a completion was canceled.
1682 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1683 let canceled = if self.pending_completions.pop().is_some() {
1684 true
1685 } else {
1686 let mut canceled = false;
1687 for pending_tool_use in self.tool_use.cancel_pending() {
1688 canceled = true;
1689 self.tool_finished(
1690 pending_tool_use.id.clone(),
1691 Some(pending_tool_use),
1692 true,
1693 cx,
1694 );
1695 }
1696 canceled
1697 };
1698 self.finalize_pending_checkpoint(cx);
1699 canceled
1700 }
1701
1702 pub fn feedback(&self) -> Option<ThreadFeedback> {
1703 self.feedback
1704 }
1705
1706 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1707 self.message_feedback.get(&message_id).copied()
1708 }
1709
1710 pub fn report_message_feedback(
1711 &mut self,
1712 message_id: MessageId,
1713 feedback: ThreadFeedback,
1714 cx: &mut Context<Self>,
1715 ) -> Task<Result<()>> {
1716 if self.message_feedback.get(&message_id) == Some(&feedback) {
1717 return Task::ready(Ok(()));
1718 }
1719
1720 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1721 let serialized_thread = self.serialize(cx);
1722 let thread_id = self.id().clone();
1723 let client = self.project.read(cx).client();
1724
1725 let enabled_tool_names: Vec<String> = self
1726 .tools()
1727 .read(cx)
1728 .enabled_tools(cx)
1729 .iter()
1730 .map(|tool| tool.name().to_string())
1731 .collect();
1732
1733 self.message_feedback.insert(message_id, feedback);
1734
1735 cx.notify();
1736
1737 let message_content = self
1738 .message(message_id)
1739 .map(|msg| msg.to_string())
1740 .unwrap_or_default();
1741
1742 cx.background_spawn(async move {
1743 let final_project_snapshot = final_project_snapshot.await;
1744 let serialized_thread = serialized_thread.await?;
1745 let thread_data =
1746 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1747
1748 let rating = match feedback {
1749 ThreadFeedback::Positive => "positive",
1750 ThreadFeedback::Negative => "negative",
1751 };
1752 telemetry::event!(
1753 "Assistant Thread Rated",
1754 rating,
1755 thread_id,
1756 enabled_tool_names,
1757 message_id = message_id.0,
1758 message_content,
1759 thread_data,
1760 final_project_snapshot
1761 );
1762 client.telemetry().flush_events();
1763
1764 Ok(())
1765 })
1766 }
1767
1768 pub fn report_feedback(
1769 &mut self,
1770 feedback: ThreadFeedback,
1771 cx: &mut Context<Self>,
1772 ) -> Task<Result<()>> {
1773 let last_assistant_message_id = self
1774 .messages
1775 .iter()
1776 .rev()
1777 .find(|msg| msg.role == Role::Assistant)
1778 .map(|msg| msg.id);
1779
1780 if let Some(message_id) = last_assistant_message_id {
1781 self.report_message_feedback(message_id, feedback, cx)
1782 } else {
1783 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1784 let serialized_thread = self.serialize(cx);
1785 let thread_id = self.id().clone();
1786 let client = self.project.read(cx).client();
1787 self.feedback = Some(feedback);
1788 cx.notify();
1789
1790 cx.background_spawn(async move {
1791 let final_project_snapshot = final_project_snapshot.await;
1792 let serialized_thread = serialized_thread.await?;
1793 let thread_data = serde_json::to_value(serialized_thread)
1794 .unwrap_or_else(|_| serde_json::Value::Null);
1795
1796 let rating = match feedback {
1797 ThreadFeedback::Positive => "positive",
1798 ThreadFeedback::Negative => "negative",
1799 };
1800 telemetry::event!(
1801 "Assistant Thread Rated",
1802 rating,
1803 thread_id,
1804 thread_data,
1805 final_project_snapshot
1806 );
1807 client.telemetry().flush_events();
1808
1809 Ok(())
1810 })
1811 }
1812 }
1813
1814 /// Create a snapshot of the current project state including git information and unsaved buffers.
1815 fn project_snapshot(
1816 project: Entity<Project>,
1817 cx: &mut Context<Self>,
1818 ) -> Task<Arc<ProjectSnapshot>> {
1819 let git_store = project.read(cx).git_store().clone();
1820 let worktree_snapshots: Vec<_> = project
1821 .read(cx)
1822 .visible_worktrees(cx)
1823 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1824 .collect();
1825
1826 cx.spawn(async move |_, cx| {
1827 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1828
1829 let mut unsaved_buffers = Vec::new();
1830 cx.update(|app_cx| {
1831 let buffer_store = project.read(app_cx).buffer_store();
1832 for buffer_handle in buffer_store.read(app_cx).buffers() {
1833 let buffer = buffer_handle.read(app_cx);
1834 if buffer.is_dirty() {
1835 if let Some(file) = buffer.file() {
1836 let path = file.path().to_string_lossy().to_string();
1837 unsaved_buffers.push(path);
1838 }
1839 }
1840 }
1841 })
1842 .ok();
1843
1844 Arc::new(ProjectSnapshot {
1845 worktree_snapshots,
1846 unsaved_buffer_paths: unsaved_buffers,
1847 timestamp: Utc::now(),
1848 })
1849 })
1850 }
1851
1852 fn worktree_snapshot(
1853 worktree: Entity<project::Worktree>,
1854 git_store: Entity<GitStore>,
1855 cx: &App,
1856 ) -> Task<WorktreeSnapshot> {
1857 cx.spawn(async move |cx| {
1858 // Get worktree path and snapshot
1859 let worktree_info = cx.update(|app_cx| {
1860 let worktree = worktree.read(app_cx);
1861 let path = worktree.abs_path().to_string_lossy().to_string();
1862 let snapshot = worktree.snapshot();
1863 (path, snapshot)
1864 });
1865
1866 let Ok((worktree_path, _snapshot)) = worktree_info else {
1867 return WorktreeSnapshot {
1868 worktree_path: String::new(),
1869 git_state: None,
1870 };
1871 };
1872
1873 let git_state = git_store
1874 .update(cx, |git_store, cx| {
1875 git_store
1876 .repositories()
1877 .values()
1878 .find(|repo| {
1879 repo.read(cx)
1880 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1881 .is_some()
1882 })
1883 .cloned()
1884 })
1885 .ok()
1886 .flatten()
1887 .map(|repo| {
1888 repo.update(cx, |repo, _| {
1889 let current_branch =
1890 repo.branch.as_ref().map(|branch| branch.name.to_string());
1891 repo.send_job(None, |state, _| async move {
1892 let RepositoryState::Local { backend, .. } = state else {
1893 return GitState {
1894 remote_url: None,
1895 head_sha: None,
1896 current_branch,
1897 diff: None,
1898 };
1899 };
1900
1901 let remote_url = backend.remote_url("origin");
1902 let head_sha = backend.head_sha();
1903 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1904
1905 GitState {
1906 remote_url,
1907 head_sha,
1908 current_branch,
1909 diff,
1910 }
1911 })
1912 })
1913 });
1914
1915 let git_state = match git_state {
1916 Some(git_state) => match git_state.ok() {
1917 Some(git_state) => git_state.await.ok(),
1918 None => None,
1919 },
1920 None => None,
1921 };
1922
1923 WorktreeSnapshot {
1924 worktree_path,
1925 git_state,
1926 }
1927 })
1928 }
1929
1930 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1931 let mut markdown = Vec::new();
1932
1933 if let Some(summary) = self.summary() {
1934 writeln!(markdown, "# {summary}\n")?;
1935 };
1936
1937 for message in self.messages() {
1938 writeln!(
1939 markdown,
1940 "## {role}\n",
1941 role = match message.role {
1942 Role::User => "User",
1943 Role::Assistant => "Assistant",
1944 Role::System => "System",
1945 }
1946 )?;
1947
1948 if !message.context.is_empty() {
1949 writeln!(markdown, "{}", message.context)?;
1950 }
1951
1952 for segment in &message.segments {
1953 match segment {
1954 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1955 MessageSegment::Thinking { text, .. } => {
1956 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
1957 }
1958 MessageSegment::RedactedThinking(_) => {}
1959 }
1960 }
1961
1962 for tool_use in self.tool_uses_for_message(message.id, cx) {
1963 writeln!(
1964 markdown,
1965 "**Use Tool: {} ({})**",
1966 tool_use.name, tool_use.id
1967 )?;
1968 writeln!(markdown, "```json")?;
1969 writeln!(
1970 markdown,
1971 "{}",
1972 serde_json::to_string_pretty(&tool_use.input)?
1973 )?;
1974 writeln!(markdown, "```")?;
1975 }
1976
1977 for tool_result in self.tool_results_for_message(message.id) {
1978 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1979 if tool_result.is_error {
1980 write!(markdown, " (Error)")?;
1981 }
1982
1983 writeln!(markdown, "**\n")?;
1984 writeln!(markdown, "{}", tool_result.content)?;
1985 }
1986 }
1987
1988 Ok(String::from_utf8_lossy(&markdown).to_string())
1989 }
1990
1991 pub fn keep_edits_in_range(
1992 &mut self,
1993 buffer: Entity<language::Buffer>,
1994 buffer_range: Range<language::Anchor>,
1995 cx: &mut Context<Self>,
1996 ) {
1997 self.action_log.update(cx, |action_log, cx| {
1998 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1999 });
2000 }
2001
2002 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2003 self.action_log
2004 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2005 }
2006
2007 pub fn reject_edits_in_ranges(
2008 &mut self,
2009 buffer: Entity<language::Buffer>,
2010 buffer_ranges: Vec<Range<language::Anchor>>,
2011 cx: &mut Context<Self>,
2012 ) -> Task<Result<()>> {
2013 self.action_log.update(cx, |action_log, cx| {
2014 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2015 })
2016 }
2017
2018 pub fn action_log(&self) -> &Entity<ActionLog> {
2019 &self.action_log
2020 }
2021
2022 pub fn project(&self) -> &Entity<Project> {
2023 &self.project
2024 }
2025
2026 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2027 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2028 return;
2029 }
2030
2031 let now = Instant::now();
2032 if let Some(last) = self.last_auto_capture_at {
2033 if now.duration_since(last).as_secs() < 10 {
2034 return;
2035 }
2036 }
2037
2038 self.last_auto_capture_at = Some(now);
2039
2040 let thread_id = self.id().clone();
2041 let github_login = self
2042 .project
2043 .read(cx)
2044 .user_store()
2045 .read(cx)
2046 .current_user()
2047 .map(|user| user.github_login.clone());
2048 let client = self.project.read(cx).client().clone();
2049 let serialize_task = self.serialize(cx);
2050
2051 cx.background_executor()
2052 .spawn(async move {
2053 if let Ok(serialized_thread) = serialize_task.await {
2054 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2055 telemetry::event!(
2056 "Agent Thread Auto-Captured",
2057 thread_id = thread_id.to_string(),
2058 thread_data = thread_data,
2059 auto_capture_reason = "tracked_user",
2060 github_login = github_login
2061 );
2062
2063 client.telemetry().flush_events();
2064 }
2065 }
2066 })
2067 .detach();
2068 }
2069
2070 pub fn cumulative_token_usage(&self) -> TokenUsage {
2071 self.cumulative_token_usage
2072 }
2073
2074 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2075 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2076 return TotalTokenUsage::default();
2077 };
2078
2079 let max = model.model.max_token_count();
2080
2081 let index = self
2082 .messages
2083 .iter()
2084 .position(|msg| msg.id == message_id)
2085 .unwrap_or(0);
2086
2087 if index == 0 {
2088 return TotalTokenUsage { total: 0, max };
2089 }
2090
2091 let token_usage = &self
2092 .request_token_usage
2093 .get(index - 1)
2094 .cloned()
2095 .unwrap_or_default();
2096
2097 TotalTokenUsage {
2098 total: token_usage.total_tokens() as usize,
2099 max,
2100 }
2101 }
2102
2103 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2104 let model_registry = LanguageModelRegistry::read_global(cx);
2105 let Some(model) = model_registry.default_model() else {
2106 return TotalTokenUsage::default();
2107 };
2108
2109 let max = model.model.max_token_count();
2110
2111 if let Some(exceeded_error) = &self.exceeded_window_error {
2112 if model.model.id() == exceeded_error.model_id {
2113 return TotalTokenUsage {
2114 total: exceeded_error.token_count,
2115 max,
2116 };
2117 }
2118 }
2119
2120 let total = self
2121 .token_usage_at_last_message()
2122 .unwrap_or_default()
2123 .total_tokens() as usize;
2124
2125 TotalTokenUsage { total, max }
2126 }
2127
2128 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2129 self.request_token_usage
2130 .get(self.messages.len().saturating_sub(1))
2131 .or_else(|| self.request_token_usage.last())
2132 .cloned()
2133 }
2134
2135 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2136 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2137 self.request_token_usage
2138 .resize(self.messages.len(), placeholder);
2139
2140 if let Some(last) = self.request_token_usage.last_mut() {
2141 *last = token_usage;
2142 }
2143 }
2144
2145 pub fn deny_tool_use(
2146 &mut self,
2147 tool_use_id: LanguageModelToolUseId,
2148 tool_name: Arc<str>,
2149 cx: &mut Context<Self>,
2150 ) {
2151 let err = Err(anyhow::anyhow!(
2152 "Permission to run tool action denied by user"
2153 ));
2154
2155 self.tool_use
2156 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2157 self.tool_finished(tool_use_id.clone(), None, true, cx);
2158 }
2159}
2160
2161#[derive(Debug, Clone, Error)]
2162pub enum ThreadError {
2163 #[error("Payment required")]
2164 PaymentRequired,
2165 #[error("Max monthly spend reached")]
2166 MaxMonthlySpendReached,
2167 #[error("Model request limit reached")]
2168 ModelRequestLimitReached { plan: Plan },
2169 #[error("Message {header}: {message}")]
2170 Message {
2171 header: SharedString,
2172 message: SharedString,
2173 },
2174}
2175
2176#[derive(Debug, Clone)]
2177pub enum ThreadEvent {
2178 ShowError(ThreadError),
2179 UsageUpdated(RequestUsage),
2180 StreamedCompletion,
2181 StreamedAssistantText(MessageId, String),
2182 StreamedAssistantThinking(MessageId, String),
2183 Stopped(Result<StopReason, Arc<anyhow::Error>>),
2184 MessageAdded(MessageId),
2185 MessageEdited(MessageId),
2186 MessageDeleted(MessageId),
2187 SummaryGenerated,
2188 SummaryChanged,
2189 UsePendingTools {
2190 tool_uses: Vec<PendingToolUse>,
2191 },
2192 ToolFinished {
2193 #[allow(unused)]
2194 tool_use_id: LanguageModelToolUseId,
2195 /// The pending tool use that corresponds to this tool.
2196 pending_tool_use: Option<PendingToolUse>,
2197 },
2198 CheckpointChanged,
2199 ToolConfirmationNeeded,
2200}
2201
2202impl EventEmitter<ThreadEvent> for Thread {}
2203
2204struct PendingCompletion {
2205 id: usize,
2206 _task: Task<()>,
2207}
2208
2209#[cfg(test)]
2210mod tests {
2211 use super::*;
2212 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2213 use assistant_settings::AssistantSettings;
2214 use context_server::ContextServerSettings;
2215 use editor::EditorSettings;
2216 use gpui::TestAppContext;
2217 use project::{FakeFs, Project};
2218 use prompt_store::PromptBuilder;
2219 use serde_json::json;
2220 use settings::{Settings, SettingsStore};
2221 use std::sync::Arc;
2222 use theme::ThemeSettings;
2223 use util::path;
2224 use workspace::Workspace;
2225
2226 #[gpui::test]
2227 async fn test_message_with_context(cx: &mut TestAppContext) {
2228 init_test_settings(cx);
2229
2230 let project = create_test_project(
2231 cx,
2232 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2233 )
2234 .await;
2235
2236 let (_workspace, _thread_store, thread, context_store) =
2237 setup_test_environment(cx, project.clone()).await;
2238
2239 add_file_to_context(&project, &context_store, "test/code.rs", cx)
2240 .await
2241 .unwrap();
2242
2243 let context =
2244 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2245
2246 // Insert user message with context
2247 let message_id = thread.update(cx, |thread, cx| {
2248 thread.insert_user_message("Please explain this code", vec![context], None, cx)
2249 });
2250
2251 // Check content and context in message object
2252 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2253
2254 // Use different path format strings based on platform for the test
2255 #[cfg(windows)]
2256 let path_part = r"test\code.rs";
2257 #[cfg(not(windows))]
2258 let path_part = "test/code.rs";
2259
2260 let expected_context = format!(
2261 r#"
2262<context>
2263The following items were attached by the user. You don't need to use other tools to read them.
2264
2265<files>
2266```rs {path_part}
2267fn main() {{
2268 println!("Hello, world!");
2269}}
2270```
2271</files>
2272</context>
2273"#
2274 );
2275
2276 assert_eq!(message.role, Role::User);
2277 assert_eq!(message.segments.len(), 1);
2278 assert_eq!(
2279 message.segments[0],
2280 MessageSegment::Text("Please explain this code".to_string())
2281 );
2282 assert_eq!(message.context, expected_context);
2283
2284 // Check message in request
2285 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2286
2287 assert_eq!(request.messages.len(), 2);
2288 let expected_full_message = format!("{}Please explain this code", expected_context);
2289 assert_eq!(request.messages[1].string_contents(), expected_full_message);
2290 }
2291
2292 #[gpui::test]
2293 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2294 init_test_settings(cx);
2295
2296 let project = create_test_project(
2297 cx,
2298 json!({
2299 "file1.rs": "fn function1() {}\n",
2300 "file2.rs": "fn function2() {}\n",
2301 "file3.rs": "fn function3() {}\n",
2302 }),
2303 )
2304 .await;
2305
2306 let (_, _thread_store, thread, context_store) =
2307 setup_test_environment(cx, project.clone()).await;
2308
2309 // Open files individually
2310 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2311 .await
2312 .unwrap();
2313 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2314 .await
2315 .unwrap();
2316 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2317 .await
2318 .unwrap();
2319
2320 // Get the context objects
2321 let contexts = context_store.update(cx, |store, _| store.context().clone());
2322 assert_eq!(contexts.len(), 3);
2323
2324 // First message with context 1
2325 let message1_id = thread.update(cx, |thread, cx| {
2326 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2327 });
2328
2329 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2330 let message2_id = thread.update(cx, |thread, cx| {
2331 thread.insert_user_message(
2332 "Message 2",
2333 vec![contexts[0].clone(), contexts[1].clone()],
2334 None,
2335 cx,
2336 )
2337 });
2338
2339 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2340 let message3_id = thread.update(cx, |thread, cx| {
2341 thread.insert_user_message(
2342 "Message 3",
2343 vec![
2344 contexts[0].clone(),
2345 contexts[1].clone(),
2346 contexts[2].clone(),
2347 ],
2348 None,
2349 cx,
2350 )
2351 });
2352
2353 // Check what contexts are included in each message
2354 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2355 (
2356 thread.message(message1_id).unwrap().clone(),
2357 thread.message(message2_id).unwrap().clone(),
2358 thread.message(message3_id).unwrap().clone(),
2359 )
2360 });
2361
2362 // First message should include context 1
2363 assert!(message1.context.contains("file1.rs"));
2364
2365 // Second message should include only context 2 (not 1)
2366 assert!(!message2.context.contains("file1.rs"));
2367 assert!(message2.context.contains("file2.rs"));
2368
2369 // Third message should include only context 3 (not 1 or 2)
2370 assert!(!message3.context.contains("file1.rs"));
2371 assert!(!message3.context.contains("file2.rs"));
2372 assert!(message3.context.contains("file3.rs"));
2373
2374 // Check entire request to make sure all contexts are properly included
2375 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2376
2377 // The request should contain all 3 messages
2378 assert_eq!(request.messages.len(), 4);
2379
2380 // Check that the contexts are properly formatted in each message
2381 assert!(request.messages[1].string_contents().contains("file1.rs"));
2382 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2383 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2384
2385 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2386 assert!(request.messages[2].string_contents().contains("file2.rs"));
2387 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2388
2389 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2390 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2391 assert!(request.messages[3].string_contents().contains("file3.rs"));
2392 }
2393
2394 #[gpui::test]
2395 async fn test_message_without_files(cx: &mut TestAppContext) {
2396 init_test_settings(cx);
2397
2398 let project = create_test_project(
2399 cx,
2400 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2401 )
2402 .await;
2403
2404 let (_, _thread_store, thread, _context_store) =
2405 setup_test_environment(cx, project.clone()).await;
2406
2407 // Insert user message without any context (empty context vector)
2408 let message_id = thread.update(cx, |thread, cx| {
2409 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2410 });
2411
2412 // Check content and context in message object
2413 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2414
2415 // Context should be empty when no files are included
2416 assert_eq!(message.role, Role::User);
2417 assert_eq!(message.segments.len(), 1);
2418 assert_eq!(
2419 message.segments[0],
2420 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2421 );
2422 assert_eq!(message.context, "");
2423
2424 // Check message in request
2425 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2426
2427 assert_eq!(request.messages.len(), 2);
2428 assert_eq!(
2429 request.messages[1].string_contents(),
2430 "What is the best way to learn Rust?"
2431 );
2432
2433 // Add second message, also without context
2434 let message2_id = thread.update(cx, |thread, cx| {
2435 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2436 });
2437
2438 let message2 =
2439 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2440 assert_eq!(message2.context, "");
2441
2442 // Check that both messages appear in the request
2443 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2444
2445 assert_eq!(request.messages.len(), 3);
2446 assert_eq!(
2447 request.messages[1].string_contents(),
2448 "What is the best way to learn Rust?"
2449 );
2450 assert_eq!(
2451 request.messages[2].string_contents(),
2452 "Are there any good books?"
2453 );
2454 }
2455
2456 #[gpui::test]
2457 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2458 init_test_settings(cx);
2459
2460 let project = create_test_project(
2461 cx,
2462 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2463 )
2464 .await;
2465
2466 let (_workspace, _thread_store, thread, context_store) =
2467 setup_test_environment(cx, project.clone()).await;
2468
2469 // Open buffer and add it to context
2470 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2471 .await
2472 .unwrap();
2473
2474 let context =
2475 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2476
2477 // Insert user message with the buffer as context
2478 thread.update(cx, |thread, cx| {
2479 thread.insert_user_message("Explain this code", vec![context], None, cx)
2480 });
2481
2482 // Create a request and check that it doesn't have a stale buffer warning yet
2483 let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2484
2485 // Make sure we don't have a stale file warning yet
2486 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2487 msg.string_contents()
2488 .contains("These files changed since last read:")
2489 });
2490 assert!(
2491 !has_stale_warning,
2492 "Should not have stale buffer warning before buffer is modified"
2493 );
2494
2495 // Modify the buffer
2496 buffer.update(cx, |buffer, cx| {
2497 // Find a position at the end of line 1
2498 buffer.edit(
2499 [(1..1, "\n println!(\"Added a new line\");\n")],
2500 None,
2501 cx,
2502 );
2503 });
2504
2505 // Insert another user message without context
2506 thread.update(cx, |thread, cx| {
2507 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2508 });
2509
2510 // Create a new request and check for the stale buffer warning
2511 let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2512
2513 // We should have a stale file warning as the last message
2514 let last_message = new_request
2515 .messages
2516 .last()
2517 .expect("Request should have messages");
2518
2519 // The last message should be the stale buffer notification
2520 assert_eq!(last_message.role, Role::User);
2521
2522 // Check the exact content of the message
2523 let expected_content = "These files changed since last read:\n- code.rs\n";
2524 assert_eq!(
2525 last_message.string_contents(),
2526 expected_content,
2527 "Last message should be exactly the stale buffer notification"
2528 );
2529 }
2530
2531 fn init_test_settings(cx: &mut TestAppContext) {
2532 cx.update(|cx| {
2533 let settings_store = SettingsStore::test(cx);
2534 cx.set_global(settings_store);
2535 language::init(cx);
2536 Project::init_settings(cx);
2537 AssistantSettings::register(cx);
2538 prompt_store::init(cx);
2539 thread_store::init(cx);
2540 workspace::init_settings(cx);
2541 ThemeSettings::register(cx);
2542 ContextServerSettings::register(cx);
2543 EditorSettings::register(cx);
2544 });
2545 }
2546
2547 // Helper to create a test project with test files
2548 async fn create_test_project(
2549 cx: &mut TestAppContext,
2550 files: serde_json::Value,
2551 ) -> Entity<Project> {
2552 let fs = FakeFs::new(cx.executor());
2553 fs.insert_tree(path!("/test"), files).await;
2554 Project::test(fs, [path!("/test").as_ref()], cx).await
2555 }
2556
2557 async fn setup_test_environment(
2558 cx: &mut TestAppContext,
2559 project: Entity<Project>,
2560 ) -> (
2561 Entity<Workspace>,
2562 Entity<ThreadStore>,
2563 Entity<Thread>,
2564 Entity<ContextStore>,
2565 ) {
2566 let (workspace, cx) =
2567 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2568
2569 let thread_store = cx
2570 .update(|_, cx| {
2571 ThreadStore::load(
2572 project.clone(),
2573 cx.new(|_| ToolWorkingSet::default()),
2574 Arc::new(PromptBuilder::new(None).unwrap()),
2575 cx,
2576 )
2577 })
2578 .await
2579 .unwrap();
2580
2581 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2582 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2583
2584 (workspace, thread_store, thread, context_store)
2585 }
2586
2587 async fn add_file_to_context(
2588 project: &Entity<Project>,
2589 context_store: &Entity<ContextStore>,
2590 path: &str,
2591 cx: &mut TestAppContext,
2592 ) -> Result<Entity<language::Buffer>> {
2593 let buffer_path = project
2594 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2595 .unwrap();
2596
2597 let buffer = project
2598 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2599 .await
2600 .unwrap();
2601
2602 context_store
2603 .update(cx, |store, cx| {
2604 store.add_file_from_buffer(buffer.clone(), cx)
2605 })
2606 .await?;
2607
2608 Ok(buffer)
2609 }
2610}