1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
42
43#[derive(Debug, Clone, Copy)]
44pub enum RequestKind {
45 Chat,
46 /// Used when summarizing a thread.
47 Summarize,
48}
49
50#[derive(
51 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
52)]
53pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
73/// The ID of the user prompt that initiated a request.
74///
75/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
76#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
77pub struct PromptId(Arc<str>);
78
79impl PromptId {
80 pub fn new() -> Self {
81 Self(Uuid::new_v4().to_string().into())
82 }
83}
84
85impl std::fmt::Display for PromptId {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", self.0)
88 }
89}
90
91#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
92pub struct MessageId(pub(crate) usize);
93
94impl MessageId {
95 fn post_inc(&mut self) -> Self {
96 Self(post_inc(&mut self.0))
97 }
98}
99
100/// A message in a [`Thread`].
101#[derive(Debug, Clone)]
102pub struct Message {
103 pub id: MessageId,
104 pub role: Role,
105 pub segments: Vec<MessageSegment>,
106 pub context: String,
107}
108
109impl Message {
    /// Returns whether the message contains meaningful, displayable content.
    ///
    /// The model sometimes runs tools without producing any text, or emits only the
    /// [`USING_TOOL_MARKER`], and such messages have nothing worth rendering.
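    ///
    /// A rough sketch of the intended behavior (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let marker_only = Message {
    ///     id: MessageId(0),
    ///     role: Role::Assistant,
    ///     segments: vec![MessageSegment::Text(USING_TOOL_MARKER.to_string())],
    ///     context: String::new(),
    /// };
    /// assert!(!marker_only.should_display_content());
    /// ```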
112 pub fn should_display_content(&self) -> bool {
113 self.segments.iter().all(|segment| segment.should_display())
114 }
115
116 pub fn push_thinking(&mut self, text: &str) {
117 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
118 segment.push_str(text);
119 } else {
120 self.segments
121 .push(MessageSegment::Thinking(text.to_string()));
122 }
123 }
124
125 pub fn push_text(&mut self, text: &str) {
126 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
127 segment.push_str(text);
128 } else {
129 self.segments.push(MessageSegment::Text(text.to_string()));
130 }
131 }
132
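    /// Renders the message as plain text: any attached context comes first, followed by
    /// the segments in order, with thinking segments wrapped in `<think>` tags.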
133 pub fn to_string(&self) -> String {
134 let mut result = String::new();
135
136 if !self.context.is_empty() {
137 result.push_str(&self.context);
138 }
139
140 for segment in &self.segments {
141 match segment {
142 MessageSegment::Text(text) => result.push_str(text),
143 MessageSegment::Thinking(text) => {
144 result.push_str("<think>");
145 result.push_str(text);
146 result.push_str("</think>");
147 }
148 }
149 }
150
151 result
152 }
153}
154
155#[derive(Debug, Clone, PartialEq, Eq)]
156pub enum MessageSegment {
157 Text(String),
158 Thinking(String),
159}
160
161impl MessageSegment {
162 pub fn text_mut(&mut self) -> &mut String {
163 match self {
164 Self::Text(text) => text,
165 Self::Thinking(text) => text,
166 }
167 }
168
169 pub fn should_display(&self) -> bool {
170 // We add USING_TOOL_MARKER when making a request that includes tool uses
171 // without non-whitespace text around them, and this can cause the model
172 // to mimic the pattern, so we consider those segments not displayable.
        match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
        }
177 }
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct ProjectSnapshot {
182 pub worktree_snapshots: Vec<WorktreeSnapshot>,
183 pub unsaved_buffer_paths: Vec<String>,
184 pub timestamp: DateTime<Utc>,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct WorktreeSnapshot {
189 pub worktree_path: String,
190 pub git_state: Option<GitState>,
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct GitState {
195 pub remote_url: Option<String>,
196 pub head_sha: Option<String>,
197 pub current_branch: Option<String>,
198 pub diff: Option<String>,
199}
200
201#[derive(Clone)]
202pub struct ThreadCheckpoint {
203 message_id: MessageId,
204 git_checkpoint: GitStoreCheckpoint,
205}
206
207#[derive(Copy, Clone, Debug, PartialEq, Eq)]
208pub enum ThreadFeedback {
209 Positive,
210 Negative,
211}
212
213pub enum LastRestoreCheckpoint {
214 Pending {
215 message_id: MessageId,
216 },
217 Error {
218 message_id: MessageId,
219 error: String,
220 },
221}
222
223impl LastRestoreCheckpoint {
224 pub fn message_id(&self) -> MessageId {
225 match self {
226 LastRestoreCheckpoint::Pending { message_id } => *message_id,
227 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
228 }
229 }
230}
231
232#[derive(Clone, Debug, Default, Serialize, Deserialize)]
233pub enum DetailedSummaryState {
234 #[default]
235 NotGenerated,
236 Generating {
237 message_id: MessageId,
238 },
239 Generated {
240 text: SharedString,
241 message_id: MessageId,
242 },
243}
244
245#[derive(Default)]
246pub struct TotalTokenUsage {
247 pub total: usize,
248 pub max: usize,
249}
250
251impl TotalTokenUsage {
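    /// Classifies the current usage against the model's context window: `Exceeded` once
    /// the total reaches the maximum, `Warning` once it crosses the warning threshold
    /// (0.8 by default; overridable via `ZED_THREAD_WARNING_THRESHOLD` in debug builds),
    /// and `Normal` otherwise.
    ///
    /// A rough sketch of the expected results (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// assert_eq!(TotalTokenUsage { total: 500, max: 1000 }.ratio(), TokenUsageRatio::Normal);
    /// assert_eq!(TotalTokenUsage { total: 850, max: 1000 }.ratio(), TokenUsageRatio::Warning);
    /// assert_eq!(TotalTokenUsage { total: 1000, max: 1000 }.ratio(), TokenUsageRatio::Exceeded);
    /// ```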
252 pub fn ratio(&self) -> TokenUsageRatio {
253 #[cfg(debug_assertions)]
254 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
255 .unwrap_or("0.8".to_string())
256 .parse()
257 .unwrap();
258 #[cfg(not(debug_assertions))]
259 let warning_threshold: f32 = 0.8;
260
261 if self.total >= self.max {
262 TokenUsageRatio::Exceeded
263 } else if self.total as f32 / self.max as f32 >= warning_threshold {
264 TokenUsageRatio::Warning
265 } else {
266 TokenUsageRatio::Normal
267 }
268 }
269
270 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
271 TotalTokenUsage {
272 total: self.total + tokens,
273 max: self.max,
274 }
275 }
276}
277
278#[derive(Debug, Default, PartialEq, Eq)]
279pub enum TokenUsageRatio {
280 #[default]
281 Normal,
282 Warning,
283 Exceeded,
284}
285
286/// A thread of conversation with the LLM.
287pub struct Thread {
288 id: ThreadId,
289 updated_at: DateTime<Utc>,
290 summary: Option<SharedString>,
291 pending_summary: Task<Option<()>>,
292 detailed_summary_state: DetailedSummaryState,
293 messages: Vec<Message>,
294 next_message_id: MessageId,
295 last_prompt_id: PromptId,
296 context: BTreeMap<ContextId, AssistantContext>,
297 context_by_message: HashMap<MessageId, Vec<ContextId>>,
298 project_context: SharedProjectContext,
299 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
300 completion_count: usize,
301 pending_completions: Vec<PendingCompletion>,
302 project: Entity<Project>,
303 prompt_builder: Arc<PromptBuilder>,
304 tools: Entity<ToolWorkingSet>,
305 tool_use: ToolUseState,
306 action_log: Entity<ActionLog>,
307 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
308 pending_checkpoint: Option<ThreadCheckpoint>,
309 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
310 request_token_usage: Vec<TokenUsage>,
311 cumulative_token_usage: TokenUsage,
312 exceeded_window_error: Option<ExceededWindowError>,
313 feedback: Option<ThreadFeedback>,
314 message_feedback: HashMap<MessageId, ThreadFeedback>,
315 last_auto_capture_at: Option<Instant>,
316}
317
318#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct ExceededWindowError {
    /// The model that was in use when the last message exceeded the context window.
    model_id: LanguageModelId,
    /// The token count, including the last message.
323 token_count: usize,
324}
325
326impl Thread {
327 pub fn new(
328 project: Entity<Project>,
329 tools: Entity<ToolWorkingSet>,
330 prompt_builder: Arc<PromptBuilder>,
331 system_prompt: SharedProjectContext,
332 cx: &mut Context<Self>,
333 ) -> Self {
334 Self {
335 id: ThreadId::new(),
336 updated_at: Utc::now(),
337 summary: None,
338 pending_summary: Task::ready(None),
339 detailed_summary_state: DetailedSummaryState::NotGenerated,
340 messages: Vec::new(),
341 next_message_id: MessageId(0),
342 last_prompt_id: PromptId::new(),
343 context: BTreeMap::default(),
344 context_by_message: HashMap::default(),
345 project_context: system_prompt,
346 checkpoints_by_message: HashMap::default(),
347 completion_count: 0,
348 pending_completions: Vec::new(),
349 project: project.clone(),
350 prompt_builder,
351 tools: tools.clone(),
352 last_restore_checkpoint: None,
353 pending_checkpoint: None,
354 tool_use: ToolUseState::new(tools.clone()),
355 action_log: cx.new(|_| ActionLog::new(project.clone())),
356 initial_project_snapshot: {
357 let project_snapshot = Self::project_snapshot(project, cx);
358 cx.foreground_executor()
359 .spawn(async move { Some(project_snapshot.await) })
360 .shared()
361 },
362 request_token_usage: Vec::new(),
363 cumulative_token_usage: TokenUsage::default(),
364 exceeded_window_error: None,
365 feedback: None,
366 message_feedback: HashMap::default(),
367 last_auto_capture_at: None,
368 }
369 }
370
371 pub fn deserialize(
372 id: ThreadId,
373 serialized: SerializedThread,
374 project: Entity<Project>,
375 tools: Entity<ToolWorkingSet>,
376 prompt_builder: Arc<PromptBuilder>,
377 project_context: SharedProjectContext,
378 cx: &mut Context<Self>,
379 ) -> Self {
380 let next_message_id = MessageId(
381 serialized
382 .messages
383 .last()
384 .map(|message| message.id.0 + 1)
385 .unwrap_or(0),
386 );
387 let tool_use =
388 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
389
390 Self {
391 id,
392 updated_at: serialized.updated_at,
393 summary: Some(serialized.summary),
394 pending_summary: Task::ready(None),
395 detailed_summary_state: serialized.detailed_summary_state,
396 messages: serialized
397 .messages
398 .into_iter()
399 .map(|message| Message {
400 id: message.id,
401 role: message.role,
402 segments: message
403 .segments
404 .into_iter()
405 .map(|segment| match segment {
406 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
407 SerializedMessageSegment::Thinking { text } => {
408 MessageSegment::Thinking(text)
409 }
410 })
411 .collect(),
412 context: message.context,
413 })
414 .collect(),
415 next_message_id,
416 last_prompt_id: PromptId::new(),
417 context: BTreeMap::default(),
418 context_by_message: HashMap::default(),
419 project_context,
420 checkpoints_by_message: HashMap::default(),
421 completion_count: 0,
422 pending_completions: Vec::new(),
423 last_restore_checkpoint: None,
424 pending_checkpoint: None,
425 project: project.clone(),
426 prompt_builder,
427 tools,
428 tool_use,
429 action_log: cx.new(|_| ActionLog::new(project)),
430 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
431 request_token_usage: serialized.request_token_usage,
432 cumulative_token_usage: serialized.cumulative_token_usage,
433 exceeded_window_error: None,
434 feedback: None,
435 message_feedback: HashMap::default(),
436 last_auto_capture_at: None,
437 }
438 }
439
440 pub fn id(&self) -> &ThreadId {
441 &self.id
442 }
443
444 pub fn is_empty(&self) -> bool {
445 self.messages.is_empty()
446 }
447
448 pub fn updated_at(&self) -> DateTime<Utc> {
449 self.updated_at
450 }
451
452 pub fn touch_updated_at(&mut self) {
453 self.updated_at = Utc::now();
454 }
455
456 pub fn advance_prompt_id(&mut self) {
457 self.last_prompt_id = PromptId::new();
458 }
459
460 pub fn summary(&self) -> Option<SharedString> {
461 self.summary.clone()
462 }
463
464 pub fn project_context(&self) -> SharedProjectContext {
465 self.project_context.clone()
466 }
467
468 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
469
470 pub fn summary_or_default(&self) -> SharedString {
471 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
472 }
473
474 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
475 let Some(current_summary) = &self.summary else {
476 // Don't allow setting summary until generated
477 return;
478 };
479
480 let mut new_summary = new_summary.into();
481
482 if new_summary.is_empty() {
483 new_summary = Self::DEFAULT_SUMMARY;
484 }
485
486 if current_summary != &new_summary {
487 self.summary = Some(new_summary);
488 cx.emit(ThreadEvent::SummaryChanged);
489 }
490 }
491
492 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
493 self.latest_detailed_summary()
494 .unwrap_or_else(|| self.text().into())
495 }
496
497 fn latest_detailed_summary(&self) -> Option<SharedString> {
498 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
499 Some(text.clone())
500 } else {
501 None
502 }
503 }
504
505 pub fn message(&self, id: MessageId) -> Option<&Message> {
506 self.messages.iter().find(|message| message.id == id)
507 }
508
509 pub fn messages(&self) -> impl Iterator<Item = &Message> {
510 self.messages.iter()
511 }
512
513 pub fn is_generating(&self) -> bool {
514 !self.pending_completions.is_empty() || !self.all_tools_finished()
515 }
516
517 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
518 &self.tools
519 }
520
521 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
522 self.tool_use
523 .pending_tool_uses()
524 .into_iter()
525 .find(|tool_use| &tool_use.id == id)
526 }
527
528 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
529 self.tool_use
530 .pending_tool_uses()
531 .into_iter()
532 .filter(|tool_use| tool_use.status.needs_confirmation())
533 }
534
535 pub fn has_pending_tool_uses(&self) -> bool {
536 !self.tool_use.pending_tool_uses().is_empty()
537 }
538
539 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
540 self.checkpoints_by_message.get(&id).cloned()
541 }
542
543 pub fn restore_checkpoint(
544 &mut self,
545 checkpoint: ThreadCheckpoint,
546 cx: &mut Context<Self>,
547 ) -> Task<Result<()>> {
548 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
549 message_id: checkpoint.message_id,
550 });
551 cx.emit(ThreadEvent::CheckpointChanged);
552 cx.notify();
553
554 let git_store = self.project().read(cx).git_store().clone();
555 let restore = git_store.update(cx, |git_store, cx| {
556 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
557 });
558
559 cx.spawn(async move |this, cx| {
560 let result = restore.await;
561 this.update(cx, |this, cx| {
562 if let Err(err) = result.as_ref() {
563 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
564 message_id: checkpoint.message_id,
565 error: err.to_string(),
566 });
567 } else {
568 this.truncate(checkpoint.message_id, cx);
569 this.last_restore_checkpoint = None;
570 }
571 this.pending_checkpoint = None;
572 cx.emit(ThreadEvent::CheckpointChanged);
573 cx.notify();
574 })?;
575 result
576 })
577 }
578
579 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
580 let pending_checkpoint = if self.is_generating() {
581 return;
582 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
583 checkpoint
584 } else {
585 return;
586 };
587
588 let git_store = self.project.read(cx).git_store().clone();
589 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
590 cx.spawn(async move |this, cx| match final_checkpoint.await {
591 Ok(final_checkpoint) => {
592 let equal = git_store
593 .update(cx, |store, cx| {
594 store.compare_checkpoints(
595 pending_checkpoint.git_checkpoint.clone(),
596 final_checkpoint.clone(),
597 cx,
598 )
599 })?
600 .await
601 .unwrap_or(false);
602
603 if equal {
604 git_store
605 .update(cx, |store, cx| {
606 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
607 })?
608 .detach();
609 } else {
610 this.update(cx, |this, cx| {
611 this.insert_checkpoint(pending_checkpoint, cx)
612 })?;
613 }
614
615 git_store
616 .update(cx, |store, cx| {
617 store.delete_checkpoint(final_checkpoint, cx)
618 })?
619 .detach();
620
621 Ok(())
622 }
623 Err(_) => this.update(cx, |this, cx| {
624 this.insert_checkpoint(pending_checkpoint, cx)
625 }),
626 })
627 .detach();
628 }
629
630 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
631 self.checkpoints_by_message
632 .insert(checkpoint.message_id, checkpoint);
633 cx.emit(ThreadEvent::CheckpointChanged);
634 cx.notify();
635 }
636
637 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
638 self.last_restore_checkpoint.as_ref()
639 }
640
641 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
642 let Some(message_ix) = self
643 .messages
644 .iter()
645 .rposition(|message| message.id == message_id)
646 else {
647 return;
648 };
649 for deleted_message in self.messages.drain(message_ix..) {
650 self.context_by_message.remove(&deleted_message.id);
651 self.checkpoints_by_message.remove(&deleted_message.id);
652 }
653 cx.notify();
654 }
655
656 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
657 self.context_by_message
658 .get(&id)
659 .into_iter()
660 .flat_map(|context| {
661 context
662 .iter()
663 .filter_map(|context_id| self.context.get(&context_id))
664 })
665 }
666
667 /// Returns whether all of the tool uses have finished running.
668 pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // we've finished running all of the pending tools.
671 self.tool_use
672 .pending_tool_uses()
673 .iter()
674 .all(|tool_use| tool_use.status.is_error())
675 }
676
677 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
678 self.tool_use.tool_uses_for_message(id, cx)
679 }
680
681 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
682 self.tool_use.tool_results_for_message(id)
683 }
684
685 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
686 self.tool_use.tool_result(id)
687 }
688
689 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
690 Some(&self.tool_use.tool_result(id)?.content)
691 }
692
693 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
694 self.tool_use.tool_result_card(id).cloned()
695 }
696
697 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
698 self.tool_use.message_has_tool_results(message_id)
699 }
700
701 /// Filter out contexts that have already been included in previous messages
702 pub fn filter_new_context<'a>(
703 &self,
704 context: impl Iterator<Item = &'a AssistantContext>,
705 ) -> impl Iterator<Item = &'a AssistantContext> {
706 context.filter(|ctx| self.is_context_new(ctx))
707 }
708
709 fn is_context_new(&self, context: &AssistantContext) -> bool {
710 !self.context.contains_key(&context.id())
711 }
712
713 pub fn insert_user_message(
714 &mut self,
715 text: impl Into<String>,
716 context: Vec<AssistantContext>,
717 git_checkpoint: Option<GitStoreCheckpoint>,
718 cx: &mut Context<Self>,
719 ) -> MessageId {
720 let text = text.into();
721
722 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
723
724 let new_context: Vec<_> = context
725 .into_iter()
726 .filter(|ctx| self.is_context_new(ctx))
727 .collect();
728
729 if !new_context.is_empty() {
730 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
731 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
732 message.context = context_string;
733 }
734 }
735
736 self.action_log.update(cx, |log, cx| {
737 // Track all buffers added as context
738 for ctx in &new_context {
739 match ctx {
740 AssistantContext::File(file_ctx) => {
741 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
742 }
743 AssistantContext::Directory(dir_ctx) => {
744 for context_buffer in &dir_ctx.context_buffers {
745 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
746 }
747 }
748 AssistantContext::Symbol(symbol_ctx) => {
749 log.buffer_added_as_context(
750 symbol_ctx.context_symbol.buffer.clone(),
751 cx,
752 );
753 }
754 AssistantContext::Excerpt(excerpt_context) => {
755 log.buffer_added_as_context(
756 excerpt_context.context_buffer.buffer.clone(),
757 cx,
758 );
759 }
760 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
761 }
762 }
763 });
764 }
765
766 let context_ids = new_context
767 .iter()
768 .map(|context| context.id())
769 .collect::<Vec<_>>();
770 self.context.extend(
771 new_context
772 .into_iter()
773 .map(|context| (context.id(), context)),
774 );
775 self.context_by_message.insert(message_id, context_ids);
776
777 if let Some(git_checkpoint) = git_checkpoint {
778 self.pending_checkpoint = Some(ThreadCheckpoint {
779 message_id,
780 git_checkpoint,
781 });
782 }
783
784 self.auto_capture_telemetry(cx);
785
786 message_id
787 }
788
789 pub fn insert_message(
790 &mut self,
791 role: Role,
792 segments: Vec<MessageSegment>,
793 cx: &mut Context<Self>,
794 ) -> MessageId {
795 let id = self.next_message_id.post_inc();
796 self.messages.push(Message {
797 id,
798 role,
799 segments,
800 context: String::new(),
801 });
802 self.touch_updated_at();
803 cx.emit(ThreadEvent::MessageAdded(id));
804 id
805 }
806
807 pub fn edit_message(
808 &mut self,
809 id: MessageId,
810 new_role: Role,
811 new_segments: Vec<MessageSegment>,
812 cx: &mut Context<Self>,
813 ) -> bool {
814 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
815 return false;
816 };
817 message.role = new_role;
818 message.segments = new_segments;
819 self.touch_updated_at();
820 cx.emit(ThreadEvent::MessageEdited(id));
821 true
822 }
823
824 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
825 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
826 return false;
827 };
828 self.messages.remove(index);
829 self.context_by_message.remove(&id);
830 self.touch_updated_at();
831 cx.emit(ThreadEvent::MessageDeleted(id));
832 true
833 }
834
835 /// Returns the representation of this [`Thread`] in a textual form.
836 ///
837 /// This is the representation we use when attaching a thread as context to another thread.
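    ///
    /// The output looks roughly like:
    ///
    /// ```text
    /// User:
    /// How do I exit Vim?
    /// Assistant:
    /// <think>recall the command</think>Press Escape, then type `:q` and press Enter.
    /// ```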
838 pub fn text(&self) -> String {
839 let mut text = String::new();
840
841 for message in &self.messages {
842 text.push_str(match message.role {
843 language_model::Role::User => "User:",
844 language_model::Role::Assistant => "Assistant:",
845 language_model::Role::System => "System:",
846 });
847 text.push('\n');
848
849 for segment in &message.segments {
850 match segment {
851 MessageSegment::Text(content) => text.push_str(content),
852 MessageSegment::Thinking(content) => {
853 text.push_str(&format!("<think>{}</think>", content))
854 }
855 }
856 }
857 text.push('\n');
858 }
859
860 text
861 }
862
863 /// Serializes this thread into a format for storage or telemetry.
864 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
865 let initial_project_snapshot = self.initial_project_snapshot.clone();
866 cx.spawn(async move |this, cx| {
867 let initial_project_snapshot = initial_project_snapshot.await;
868 this.read_with(cx, |this, cx| SerializedThread {
869 version: SerializedThread::VERSION.to_string(),
870 summary: this.summary_or_default(),
871 updated_at: this.updated_at(),
872 messages: this
873 .messages()
874 .map(|message| SerializedMessage {
875 id: message.id,
876 role: message.role,
877 segments: message
878 .segments
879 .iter()
880 .map(|segment| match segment {
881 MessageSegment::Text(text) => {
882 SerializedMessageSegment::Text { text: text.clone() }
883 }
884 MessageSegment::Thinking(text) => {
885 SerializedMessageSegment::Thinking { text: text.clone() }
886 }
887 })
888 .collect(),
889 tool_uses: this
890 .tool_uses_for_message(message.id, cx)
891 .into_iter()
892 .map(|tool_use| SerializedToolUse {
893 id: tool_use.id,
894 name: tool_use.name,
895 input: tool_use.input,
896 })
897 .collect(),
898 tool_results: this
899 .tool_results_for_message(message.id)
900 .into_iter()
901 .map(|tool_result| SerializedToolResult {
902 tool_use_id: tool_result.tool_use_id.clone(),
903 is_error: tool_result.is_error,
904 content: tool_result.content.clone(),
905 })
906 .collect(),
907 context: message.context.clone(),
908 })
909 .collect(),
910 initial_project_snapshot,
911 cumulative_token_usage: this.cumulative_token_usage,
912 request_token_usage: this.request_token_usage.clone(),
913 detailed_summary_state: this.detailed_summary_state.clone(),
914 exceeded_window_error: this.exceeded_window_error.clone(),
915 })
916 })
917 }
918
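    /// Builds a completion request for this thread and starts streaming the model's
    /// response, attaching the enabled tools when the model supports tool use.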
919 pub fn send_to_model(
920 &mut self,
921 model: Arc<dyn LanguageModel>,
922 request_kind: RequestKind,
923 cx: &mut Context<Self>,
924 ) {
925 let mut request = self.to_completion_request(request_kind, cx);
926 if model.supports_tools() {
927 request.tools = {
928 let mut tools = Vec::new();
929 tools.extend(
930 self.tools()
931 .read(cx)
932 .enabled_tools(cx)
933 .into_iter()
934 .filter_map(|tool| {
                            // Skip tools whose input schema can't be produced for this model's tool input format.
936 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
937 Some(LanguageModelRequestTool {
938 name: tool.name(),
939 description: tool.description(),
940 input_schema,
941 })
942 }),
943 );
944
945 tools
946 };
947 }
948
949 self.stream_completion(request, model, cx);
950 }
951
952 pub fn used_tools_since_last_user_message(&self) -> bool {
953 for message in self.messages.iter().rev() {
954 if self.tool_use.message_has_tool_results(message.id) {
955 return true;
956 } else if message.role == Role::User {
957 return false;
958 }
959 }
960
961 false
962 }
963
964 pub fn to_completion_request(
965 &self,
966 request_kind: RequestKind,
967 cx: &mut Context<Self>,
968 ) -> LanguageModelRequest {
969 let mut request = LanguageModelRequest {
970 thread_id: Some(self.id.to_string()),
971 prompt_id: Some(self.last_prompt_id.to_string()),
972 messages: vec![],
973 tools: Vec::new(),
974 stop: Vec::new(),
975 temperature: None,
976 };
977
978 if let Some(project_context) = self.project_context.borrow().as_ref() {
979 match self
980 .prompt_builder
981 .generate_assistant_system_prompt(project_context)
982 {
983 Err(err) => {
984 let message = format!("{err:?}").into();
985 log::error!("{message}");
986 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
987 header: "Error generating system prompt".into(),
988 message,
989 }));
990 }
991 Ok(system_prompt) => {
992 request.messages.push(LanguageModelRequestMessage {
993 role: Role::System,
994 content: vec![MessageContent::Text(system_prompt)],
995 cache: true,
996 });
997 }
998 }
999 } else {
1000 let message = "Context for system prompt unexpectedly not ready.".into();
1001 log::error!("{message}");
1002 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1003 header: "Error generating system prompt".into(),
1004 message,
1005 }));
1006 }
1007
1008 for message in &self.messages {
1009 let mut request_message = LanguageModelRequestMessage {
1010 role: message.role,
1011 content: Vec::new(),
1012 cache: false,
1013 };
1014
1015 match request_kind {
1016 RequestKind::Chat => {
1017 self.tool_use
1018 .attach_tool_results(message.id, &mut request_message);
1019 }
1020 RequestKind::Summarize => {
1021 // We don't care about tool use during summarization.
1022 if self.tool_use.message_has_tool_results(message.id) {
1023 continue;
1024 }
1025 }
1026 }
1027
1028 if !message.segments.is_empty() {
1029 request_message
1030 .content
1031 .push(MessageContent::Text(message.to_string()));
1032 }
1033
1034 match request_kind {
1035 RequestKind::Chat => {
1036 self.tool_use
1037 .attach_tool_uses(message.id, &mut request_message);
1038 }
1039 RequestKind::Summarize => {
1040 // We don't care about tool use during summarization.
1041 }
1042 };
1043
1044 request.messages.push(request_message);
1045 }
1046
1047 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
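        // Mark the last message as a cache anchor so providers that support prompt
        // caching can reuse the unchanged prefix of the conversation on the next request.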
1048 if let Some(last) = request.messages.last_mut() {
1049 last.cache = true;
1050 }
1051
1052 self.attached_tracked_files_state(&mut request.messages, cx);
1053
1054 request
1055 }
1056
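    /// Appends a synthetic user message listing any files that changed since the model
    /// last read them, plus a reminder to re-check project diagnostics after edits.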
1057 fn attached_tracked_files_state(
1058 &self,
1059 messages: &mut Vec<LanguageModelRequestMessage>,
1060 cx: &App,
1061 ) {
1062 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1063
1064 let mut stale_message = String::new();
1065
1066 let action_log = self.action_log.read(cx);
1067
1068 for stale_file in action_log.stale_buffers(cx) {
1069 let Some(file) = stale_file.read(cx).file() else {
1070 continue;
1071 };
1072
1073 if stale_message.is_empty() {
                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1075 }
1076
1077 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1078 }
1079
1080 let mut content = Vec::with_capacity(2);
1081
1082 if !stale_message.is_empty() {
1083 content.push(stale_message.into());
1084 }
1085
1086 if action_log.has_edited_files_since_project_diagnostics_check() {
1087 content.push(
1088 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1089 and fix all errors AND warnings you introduced! \
1090 DO NOT mention you're going to do this until you're done."
1091 .into(),
1092 );
1093 }
1094
1095 if !content.is_empty() {
1096 let context_message = LanguageModelRequestMessage {
1097 role: Role::User,
1098 content,
1099 cache: false,
1100 };
1101
1102 messages.push(context_message);
1103 }
1104 }
1105
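    /// Streams a completion for `request`, applying each event to the thread as it
    /// arrives (text and thinking chunks, tool uses, token usage) and emitting the
    /// corresponding [`ThreadEvent`]s. If the stream stops for tool use, the idle
    /// pending tools are kicked off via [`Self::use_pending_tools`].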
1106 pub fn stream_completion(
1107 &mut self,
1108 request: LanguageModelRequest,
1109 model: Arc<dyn LanguageModel>,
1110 cx: &mut Context<Self>,
1111 ) {
1112 let pending_completion_id = post_inc(&mut self.completion_count);
1113 let prompt_id = self.last_prompt_id.clone();
1114 let task = cx.spawn(async move |thread, cx| {
1115 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1116 let initial_token_usage =
1117 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1118 let stream_completion = async {
1119 let (mut events, usage) = stream_completion_future.await?;
1120 let mut stop_reason = StopReason::EndTurn;
1121 let mut current_token_usage = TokenUsage::default();
1122
1123 if let Some(usage) = usage {
1124 thread
1125 .update(cx, |_thread, cx| {
1126 cx.emit(ThreadEvent::UsageUpdated(usage));
1127 })
1128 .ok();
1129 }
1130
1131 while let Some(event) = events.next().await {
1132 let event = event?;
1133
1134 thread.update(cx, |thread, cx| {
1135 match event {
1136 LanguageModelCompletionEvent::StartMessage { .. } => {
1137 thread.insert_message(
1138 Role::Assistant,
1139 vec![MessageSegment::Text(String::new())],
1140 cx,
1141 );
1142 }
1143 LanguageModelCompletionEvent::Stop(reason) => {
1144 stop_reason = reason;
1145 }
1146 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1147 thread.update_token_usage_at_last_message(token_usage);
1148 thread.cumulative_token_usage = thread.cumulative_token_usage
1149 + token_usage
1150 - current_token_usage;
1151 current_token_usage = token_usage;
1152 }
1153 LanguageModelCompletionEvent::Text(chunk) => {
1154 if let Some(last_message) = thread.messages.last_mut() {
1155 if last_message.role == Role::Assistant {
1156 last_message.push_text(&chunk);
1157 cx.emit(ThreadEvent::StreamedAssistantText(
1158 last_message.id,
1159 chunk,
1160 ));
1161 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the
                                        // beginning of a new Assistant response.
                                        //
                                        // Importantly: we do *not* emit a `StreamedAssistantText` event here, as it
                                        // would duplicate the chunk's text in the rendered Markdown.
1167 thread.insert_message(
1168 Role::Assistant,
1169 vec![MessageSegment::Text(chunk.to_string())],
1170 cx,
1171 );
1172 };
1173 }
1174 }
1175 LanguageModelCompletionEvent::Thinking(chunk) => {
1176 if let Some(last_message) = thread.messages.last_mut() {
1177 if last_message.role == Role::Assistant {
1178 last_message.push_thinking(&chunk);
1179 cx.emit(ThreadEvent::StreamedAssistantThinking(
1180 last_message.id,
1181 chunk,
1182 ));
1183 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the
                                        // beginning of a new Assistant response.
                                        //
                                        // Importantly: we do *not* emit a `StreamedAssistantThinking` event here, as it
                                        // would duplicate the chunk's text in the rendered Markdown.
1189 thread.insert_message(
1190 Role::Assistant,
1191 vec![MessageSegment::Thinking(chunk.to_string())],
1192 cx,
1193 );
1194 };
1195 }
1196 }
1197 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1198 let last_assistant_message_id = thread
1199 .messages
1200 .iter_mut()
1201 .rfind(|message| message.role == Role::Assistant)
1202 .map(|message| message.id)
1203 .unwrap_or_else(|| {
1204 thread.insert_message(Role::Assistant, vec![], cx)
1205 });
1206
1207 thread.tool_use.request_tool_use(
1208 last_assistant_message_id,
1209 tool_use,
1210 cx,
1211 );
1212 }
1213 }
1214
1215 thread.touch_updated_at();
1216 cx.emit(ThreadEvent::StreamedCompletion);
1217 cx.notify();
1218
1219 thread.auto_capture_telemetry(cx);
1220 })?;
1221
1222 smol::future::yield_now().await;
1223 }
1224
1225 thread.update(cx, |thread, cx| {
1226 thread
1227 .pending_completions
1228 .retain(|completion| completion.id != pending_completion_id);
1229
1230 if thread.summary.is_none() && thread.messages.len() >= 2 {
1231 thread.summarize(cx);
1232 }
1233 })?;
1234
1235 anyhow::Ok(stop_reason)
1236 };
1237
1238 let result = stream_completion.await;
1239
1240 thread
1241 .update(cx, |thread, cx| {
1242 thread.finalize_pending_checkpoint(cx);
1243 match result.as_ref() {
1244 Ok(stop_reason) => match stop_reason {
1245 StopReason::ToolUse => {
1246 let tool_uses = thread.use_pending_tools(cx);
1247 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1248 }
1249 StopReason::EndTurn => {}
1250 StopReason::MaxTokens => {}
1251 },
1252 Err(error) => {
1253 if error.is::<PaymentRequiredError>() {
1254 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1255 } else if error.is::<MaxMonthlySpendReachedError>() {
1256 cx.emit(ThreadEvent::ShowError(
1257 ThreadError::MaxMonthlySpendReached,
1258 ));
1259 } else if let Some(error) =
1260 error.downcast_ref::<ModelRequestLimitReachedError>()
1261 {
1262 cx.emit(ThreadEvent::ShowError(
1263 ThreadError::ModelRequestLimitReached { plan: error.plan },
1264 ));
1265 } else if let Some(known_error) =
1266 error.downcast_ref::<LanguageModelKnownError>()
1267 {
1268 match known_error {
1269 LanguageModelKnownError::ContextWindowLimitExceeded {
1270 tokens,
1271 } => {
1272 thread.exceeded_window_error = Some(ExceededWindowError {
1273 model_id: model.id(),
1274 token_count: *tokens,
1275 });
1276 cx.notify();
1277 }
1278 }
1279 } else {
1280 let error_message = error
1281 .chain()
1282 .map(|err| err.to_string())
1283 .collect::<Vec<_>>()
1284 .join("\n");
1285 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1286 header: "Error interacting with language model".into(),
1287 message: SharedString::from(error_message.clone()),
1288 }));
1289 }
1290
1291 thread.cancel_last_completion(cx);
1292 }
1293 }
1294 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1295
1296 thread.auto_capture_telemetry(cx);
1297
1298 if let Ok(initial_usage) = initial_token_usage {
1299 let usage = thread.cumulative_token_usage - initial_usage;
1300
1301 telemetry::event!(
1302 "Assistant Thread Completion",
1303 thread_id = thread.id().to_string(),
1304 prompt_id = prompt_id,
1305 model = model.telemetry_id(),
1306 model_provider = model.provider_id().to_string(),
1307 input_tokens = usage.input_tokens,
1308 output_tokens = usage.output_tokens,
1309 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1310 cache_read_input_tokens = usage.cache_read_input_tokens,
1311 );
1312 }
1313 })
1314 .ok();
1315 });
1316
1317 self.pending_completions.push(PendingCompletion {
1318 id: pending_completion_id,
1319 _task: task,
1320 });
1321 }
1322
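    /// Asks the configured thread-summary model for a short (3-7 word) title and stores
    /// it as this thread's summary, emitting [`ThreadEvent::SummaryGenerated`] when done.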
1323 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1324 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1325 return;
1326 };
1327
1328 if !model.provider.is_authenticated(cx) {
1329 return;
1330 }
1331
1332 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1333 request.messages.push(LanguageModelRequestMessage {
1334 role: Role::User,
1335 content: vec![
1336 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1337 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1338 If the conversation is about a specific subject, include it in the title. \
1339 Be descriptive. DO NOT speak in the first person."
1340 .into(),
1341 ],
1342 cache: false,
1343 });
1344
1345 self.pending_summary = cx.spawn(async move |this, cx| {
1346 async move {
1347 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1348 let (mut messages, usage) = stream.await?;
1349
1350 if let Some(usage) = usage {
1351 this.update(cx, |_thread, cx| {
1352 cx.emit(ThreadEvent::UsageUpdated(usage));
1353 })
1354 .ok();
1355 }
1356
1357 let mut new_summary = String::new();
1358 while let Some(message) = messages.stream.next().await {
1359 let text = message?;
1360 let mut lines = text.lines();
1361 new_summary.extend(lines.next());
1362
1363 // Stop if the LLM generated multiple lines.
1364 if lines.next().is_some() {
1365 break;
1366 }
1367 }
1368
1369 this.update(cx, |this, cx| {
1370 if !new_summary.is_empty() {
1371 this.summary = Some(new_summary.into());
1372 }
1373
1374 cx.emit(ThreadEvent::SummaryGenerated);
1375 })?;
1376
1377 anyhow::Ok(())
1378 }
1379 .log_err()
1380 .await
1381 });
1382 }
1383
1384 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1385 let last_message_id = self.messages.last().map(|message| message.id)?;
1386
1387 match &self.detailed_summary_state {
1388 DetailedSummaryState::Generating { message_id, .. }
1389 | DetailedSummaryState::Generated { message_id, .. }
1390 if *message_id == last_message_id =>
1391 {
1392 // Already up-to-date
1393 return None;
1394 }
1395 _ => {}
1396 }
1397
1398 let ConfiguredModel { model, provider } =
1399 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1400
1401 if !provider.is_authenticated(cx) {
1402 return None;
1403 }
1404
1405 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1406
1407 request.messages.push(LanguageModelRequestMessage {
1408 role: Role::User,
1409 content: vec![
1410 "Generate a detailed summary of this conversation. Include:\n\
1411 1. A brief overview of what was discussed\n\
1412 2. Key facts or information discovered\n\
1413 3. Outcomes or conclusions reached\n\
1414 4. Any action items or next steps if any\n\
1415 Format it in Markdown with headings and bullet points."
1416 .into(),
1417 ],
1418 cache: false,
1419 });
1420
1421 let task = cx.spawn(async move |thread, cx| {
1422 let stream = model.stream_completion_text(request, &cx);
1423 let Some(mut messages) = stream.await.log_err() else {
1424 thread
1425 .update(cx, |this, _cx| {
1426 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1427 })
1428 .log_err();
1429
1430 return;
1431 };
1432
1433 let mut new_detailed_summary = String::new();
1434
1435 while let Some(chunk) = messages.stream.next().await {
1436 if let Some(chunk) = chunk.log_err() {
1437 new_detailed_summary.push_str(&chunk);
1438 }
1439 }
1440
1441 thread
1442 .update(cx, |this, _cx| {
1443 this.detailed_summary_state = DetailedSummaryState::Generated {
1444 text: new_detailed_summary.into(),
1445 message_id: last_message_id,
1446 };
1447 })
1448 .log_err();
1449 });
1450
1451 self.detailed_summary_state = DetailedSummaryState::Generating {
1452 message_id: last_message_id,
1453 };
1454
1455 Some(task)
1456 }
1457
1458 pub fn is_generating_detailed_summary(&self) -> bool {
1459 matches!(
1460 self.detailed_summary_state,
1461 DetailedSummaryState::Generating { .. }
1462 )
1463 }
1464
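    /// Runs every idle pending tool use. Tools that require confirmation are surfaced to
    /// the user first (unless `always_allow_tool_actions` is enabled in the assistant
    /// settings); all others are run immediately.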
1465 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1466 self.auto_capture_telemetry(cx);
1467 let request = self.to_completion_request(RequestKind::Chat, cx);
1468 let messages = Arc::new(request.messages);
1469 let pending_tool_uses = self
1470 .tool_use
1471 .pending_tool_uses()
1472 .into_iter()
1473 .filter(|tool_use| tool_use.status.is_idle())
1474 .cloned()
1475 .collect::<Vec<_>>();
1476
1477 for tool_use in pending_tool_uses.iter() {
1478 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1479 if tool.needs_confirmation(&tool_use.input, cx)
1480 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1481 {
1482 self.tool_use.confirm_tool_use(
1483 tool_use.id.clone(),
1484 tool_use.ui_text.clone(),
1485 tool_use.input.clone(),
1486 messages.clone(),
1487 tool,
1488 );
1489 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1490 } else {
1491 self.run_tool(
1492 tool_use.id.clone(),
1493 tool_use.ui_text.clone(),
1494 tool_use.input.clone(),
1495 &messages,
1496 tool,
1497 cx,
1498 );
1499 }
1500 }
1501 }
1502
1503 pending_tool_uses
1504 }
1505
1506 pub fn run_tool(
1507 &mut self,
1508 tool_use_id: LanguageModelToolUseId,
1509 ui_text: impl Into<SharedString>,
1510 input: serde_json::Value,
1511 messages: &[LanguageModelRequestMessage],
1512 tool: Arc<dyn Tool>,
1513 cx: &mut Context<Thread>,
1514 ) {
1515 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1516 self.tool_use
1517 .run_pending_tool(tool_use_id, ui_text.into(), task);
1518 }
1519
1520 fn spawn_tool_use(
1521 &mut self,
1522 tool_use_id: LanguageModelToolUseId,
1523 messages: &[LanguageModelRequestMessage],
1524 input: serde_json::Value,
1525 tool: Arc<dyn Tool>,
1526 cx: &mut Context<Thread>,
1527 ) -> Task<()> {
1528 let tool_name: Arc<str> = tool.name().into();
1529
1530 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1531 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1532 } else {
1533 tool.run(
1534 input,
1535 messages,
1536 self.project.clone(),
1537 self.action_log.clone(),
1538 cx,
1539 )
1540 };
1541
1542 // Store the card separately if it exists
1543 if let Some(card) = tool_result.card.clone() {
1544 self.tool_use
1545 .insert_tool_result_card(tool_use_id.clone(), card);
1546 }
1547
1548 cx.spawn({
1549 async move |thread: WeakEntity<Thread>, cx| {
1550 let output = tool_result.output.await;
1551
1552 thread
1553 .update(cx, |thread, cx| {
1554 let pending_tool_use = thread.tool_use.insert_tool_output(
1555 tool_use_id.clone(),
1556 tool_name,
1557 output,
1558 cx,
1559 );
1560 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1561 })
1562 .ok();
1563 }
1564 })
1565 }
1566
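    /// Records that a tool use finished. Once all pending tools have completed, the
    /// accumulated tool results are attached to a new user message and sent back to the
    /// default model, unless the run was canceled.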
1567 fn tool_finished(
1568 &mut self,
1569 tool_use_id: LanguageModelToolUseId,
1570 pending_tool_use: Option<PendingToolUse>,
1571 canceled: bool,
1572 cx: &mut Context<Self>,
1573 ) {
1574 if self.all_tools_finished() {
1575 let model_registry = LanguageModelRegistry::read_global(cx);
1576 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1577 self.attach_tool_results(cx);
1578 if !canceled {
1579 self.send_to_model(model, RequestKind::Chat, cx);
1580 }
1581 }
1582 }
1583
1584 cx.emit(ThreadEvent::ToolFinished {
1585 tool_use_id,
1586 pending_tool_use,
1587 });
1588 }
1589
1590 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1591 // Insert a user message to contain the tool results.
1592 self.insert_user_message(
1593 // TODO: Sending up a user message without any content results in the model sending back
1594 // responses that also don't have any content. We currently don't handle this case well,
1595 // so for now we provide some text to keep the model on track.
1596 "Here are the tool results.",
1597 Vec::new(),
1598 None,
1599 cx,
1600 );
1601 }
1602
1603 /// Cancels the last pending completion, if there are any pending.
1604 ///
1605 /// Returns whether a completion was canceled.
1606 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1607 let canceled = if self.pending_completions.pop().is_some() {
1608 true
1609 } else {
1610 let mut canceled = false;
1611 for pending_tool_use in self.tool_use.cancel_pending() {
1612 canceled = true;
1613 self.tool_finished(
1614 pending_tool_use.id.clone(),
1615 Some(pending_tool_use),
1616 true,
1617 cx,
1618 );
1619 }
1620 canceled
1621 };
1622 self.finalize_pending_checkpoint(cx);
1623 canceled
1624 }
1625
1626 pub fn feedback(&self) -> Option<ThreadFeedback> {
1627 self.feedback
1628 }
1629
1630 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1631 self.message_feedback.get(&message_id).copied()
1632 }
1633
1634 pub fn report_message_feedback(
1635 &mut self,
1636 message_id: MessageId,
1637 feedback: ThreadFeedback,
1638 cx: &mut Context<Self>,
1639 ) -> Task<Result<()>> {
1640 if self.message_feedback.get(&message_id) == Some(&feedback) {
1641 return Task::ready(Ok(()));
1642 }
1643
1644 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1645 let serialized_thread = self.serialize(cx);
1646 let thread_id = self.id().clone();
1647 let client = self.project.read(cx).client();
1648
1649 let enabled_tool_names: Vec<String> = self
1650 .tools()
1651 .read(cx)
1652 .enabled_tools(cx)
1653 .iter()
1654 .map(|tool| tool.name().to_string())
1655 .collect();
1656
1657 self.message_feedback.insert(message_id, feedback);
1658
1659 cx.notify();
1660
1661 let message_content = self
1662 .message(message_id)
1663 .map(|msg| msg.to_string())
1664 .unwrap_or_default();
1665
1666 cx.background_spawn(async move {
1667 let final_project_snapshot = final_project_snapshot.await;
1668 let serialized_thread = serialized_thread.await?;
1669 let thread_data =
1670 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1671
1672 let rating = match feedback {
1673 ThreadFeedback::Positive => "positive",
1674 ThreadFeedback::Negative => "negative",
1675 };
1676 telemetry::event!(
1677 "Assistant Thread Rated",
1678 rating,
1679 thread_id,
1680 enabled_tool_names,
1681 message_id = message_id.0,
1682 message_content,
1683 thread_data,
1684 final_project_snapshot
1685 );
1686 client.telemetry().flush_events();
1687
1688 Ok(())
1689 })
1690 }
1691
1692 pub fn report_feedback(
1693 &mut self,
1694 feedback: ThreadFeedback,
1695 cx: &mut Context<Self>,
1696 ) -> Task<Result<()>> {
1697 let last_assistant_message_id = self
1698 .messages
1699 .iter()
1700 .rev()
1701 .find(|msg| msg.role == Role::Assistant)
1702 .map(|msg| msg.id);
1703
1704 if let Some(message_id) = last_assistant_message_id {
1705 self.report_message_feedback(message_id, feedback, cx)
1706 } else {
1707 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1708 let serialized_thread = self.serialize(cx);
1709 let thread_id = self.id().clone();
1710 let client = self.project.read(cx).client();
1711 self.feedback = Some(feedback);
1712 cx.notify();
1713
1714 cx.background_spawn(async move {
1715 let final_project_snapshot = final_project_snapshot.await;
1716 let serialized_thread = serialized_thread.await?;
1717 let thread_data = serde_json::to_value(serialized_thread)
1718 .unwrap_or_else(|_| serde_json::Value::Null);
1719
1720 let rating = match feedback {
1721 ThreadFeedback::Positive => "positive",
1722 ThreadFeedback::Negative => "negative",
1723 };
1724 telemetry::event!(
1725 "Assistant Thread Rated",
1726 rating,
1727 thread_id,
1728 thread_data,
1729 final_project_snapshot
1730 );
1731 client.telemetry().flush_events();
1732
1733 Ok(())
1734 })
1735 }
1736 }
1737
1738 /// Create a snapshot of the current project state including git information and unsaved buffers.
1739 fn project_snapshot(
1740 project: Entity<Project>,
1741 cx: &mut Context<Self>,
1742 ) -> Task<Arc<ProjectSnapshot>> {
1743 let git_store = project.read(cx).git_store().clone();
1744 let worktree_snapshots: Vec<_> = project
1745 .read(cx)
1746 .visible_worktrees(cx)
1747 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1748 .collect();
1749
1750 cx.spawn(async move |_, cx| {
1751 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1752
1753 let mut unsaved_buffers = Vec::new();
1754 cx.update(|app_cx| {
1755 let buffer_store = project.read(app_cx).buffer_store();
1756 for buffer_handle in buffer_store.read(app_cx).buffers() {
1757 let buffer = buffer_handle.read(app_cx);
1758 if buffer.is_dirty() {
1759 if let Some(file) = buffer.file() {
1760 let path = file.path().to_string_lossy().to_string();
1761 unsaved_buffers.push(path);
1762 }
1763 }
1764 }
1765 })
1766 .ok();
1767
1768 Arc::new(ProjectSnapshot {
1769 worktree_snapshots,
1770 unsaved_buffer_paths: unsaved_buffers,
1771 timestamp: Utc::now(),
1772 })
1773 })
1774 }
1775
1776 fn worktree_snapshot(
1777 worktree: Entity<project::Worktree>,
1778 git_store: Entity<GitStore>,
1779 cx: &App,
1780 ) -> Task<WorktreeSnapshot> {
1781 cx.spawn(async move |cx| {
1782 // Get worktree path and snapshot
1783 let worktree_info = cx.update(|app_cx| {
1784 let worktree = worktree.read(app_cx);
1785 let path = worktree.abs_path().to_string_lossy().to_string();
1786 let snapshot = worktree.snapshot();
1787 (path, snapshot)
1788 });
1789
1790 let Ok((worktree_path, _snapshot)) = worktree_info else {
1791 return WorktreeSnapshot {
1792 worktree_path: String::new(),
1793 git_state: None,
1794 };
1795 };
1796
1797 let git_state = git_store
1798 .update(cx, |git_store, cx| {
1799 git_store
1800 .repositories()
1801 .values()
1802 .find(|repo| {
1803 repo.read(cx)
1804 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1805 .is_some()
1806 })
1807 .cloned()
1808 })
1809 .ok()
1810 .flatten()
1811 .map(|repo| {
1812 repo.update(cx, |repo, _| {
1813 let current_branch =
1814 repo.branch.as_ref().map(|branch| branch.name.to_string());
1815 repo.send_job(None, |state, _| async move {
1816 let RepositoryState::Local { backend, .. } = state else {
1817 return GitState {
1818 remote_url: None,
1819 head_sha: None,
1820 current_branch,
1821 diff: None,
1822 };
1823 };
1824
1825 let remote_url = backend.remote_url("origin");
1826 let head_sha = backend.head_sha();
1827 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1828
1829 GitState {
1830 remote_url,
1831 head_sha,
1832 current_branch,
1833 diff,
1834 }
1835 })
1836 })
1837 });
1838
1839 let git_state = match git_state {
1840 Some(git_state) => match git_state.ok() {
1841 Some(git_state) => git_state.await.ok(),
1842 None => None,
1843 },
1844 None => None,
1845 };
1846
1847 WorktreeSnapshot {
1848 worktree_path,
1849 git_state,
1850 }
1851 })
1852 }
1853
1854 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1855 let mut markdown = Vec::new();
1856
1857 if let Some(summary) = self.summary() {
1858 writeln!(markdown, "# {summary}\n")?;
1859 };
1860
1861 for message in self.messages() {
1862 writeln!(
1863 markdown,
1864 "## {role}\n",
1865 role = match message.role {
1866 Role::User => "User",
1867 Role::Assistant => "Assistant",
1868 Role::System => "System",
1869 }
1870 )?;
1871
1872 if !message.context.is_empty() {
1873 writeln!(markdown, "{}", message.context)?;
1874 }
1875
1876 for segment in &message.segments {
1877 match segment {
1878 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1879 MessageSegment::Thinking(text) => {
1880 writeln!(markdown, "<think>{}</think>\n", text)?
1881 }
1882 }
1883 }
1884
1885 for tool_use in self.tool_uses_for_message(message.id, cx) {
1886 writeln!(
1887 markdown,
1888 "**Use Tool: {} ({})**",
1889 tool_use.name, tool_use.id
1890 )?;
1891 writeln!(markdown, "```json")?;
1892 writeln!(
1893 markdown,
1894 "{}",
1895 serde_json::to_string_pretty(&tool_use.input)?
1896 )?;
1897 writeln!(markdown, "```")?;
1898 }
1899
1900 for tool_result in self.tool_results_for_message(message.id) {
1901 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1902 if tool_result.is_error {
1903 write!(markdown, " (Error)")?;
1904 }
1905
1906 writeln!(markdown, "**\n")?;
1907 writeln!(markdown, "{}", tool_result.content)?;
1908 }
1909 }
1910
1911 Ok(String::from_utf8_lossy(&markdown).to_string())
1912 }
1913
1914 pub fn keep_edits_in_range(
1915 &mut self,
1916 buffer: Entity<language::Buffer>,
1917 buffer_range: Range<language::Anchor>,
1918 cx: &mut Context<Self>,
1919 ) {
1920 self.action_log.update(cx, |action_log, cx| {
1921 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1922 });
1923 }
1924
1925 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1926 self.action_log
1927 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1928 }
1929
1930 pub fn reject_edits_in_ranges(
1931 &mut self,
1932 buffer: Entity<language::Buffer>,
1933 buffer_ranges: Vec<Range<language::Anchor>>,
1934 cx: &mut Context<Self>,
1935 ) -> Task<Result<()>> {
1936 self.action_log.update(cx, |action_log, cx| {
1937 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1938 })
1939 }
1940
1941 pub fn action_log(&self) -> &Entity<ActionLog> {
1942 &self.action_log
1943 }
1944
1945 pub fn project(&self) -> &Entity<Project> {
1946 &self.project
1947 }
1948
1949 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1950 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1951 return;
1952 }
1953
1954 let now = Instant::now();
1955 if let Some(last) = self.last_auto_capture_at {
1956 if now.duration_since(last).as_secs() < 10 {
1957 return;
1958 }
1959 }
1960
1961 self.last_auto_capture_at = Some(now);
1962
1963 let thread_id = self.id().clone();
1964 let github_login = self
1965 .project
1966 .read(cx)
1967 .user_store()
1968 .read(cx)
1969 .current_user()
1970 .map(|user| user.github_login.clone());
1971 let client = self.project.read(cx).client().clone();
1972 let serialize_task = self.serialize(cx);
1973
1974 cx.background_executor()
1975 .spawn(async move {
1976 if let Ok(serialized_thread) = serialize_task.await {
1977 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1978 telemetry::event!(
1979 "Agent Thread Auto-Captured",
1980 thread_id = thread_id.to_string(),
1981 thread_data = thread_data,
1982 auto_capture_reason = "tracked_user",
1983 github_login = github_login
1984 );
1985
1986 client.telemetry().flush_events();
1987 }
1988 }
1989 })
1990 .detach();
1991 }
1992
1993 pub fn cumulative_token_usage(&self) -> TokenUsage {
1994 self.cumulative_token_usage
1995 }
1996
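    /// Returns the token usage reported just before `message_id` was added, along with
    /// the default model's maximum context size.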
    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        let index = self
            .messages
            .iter()
            .position(|msg| msg.id == message_id)
            .unwrap_or(0);

        if index == 0 {
            return TotalTokenUsage { total: 0, max };
        }

        let token_usage = self
            .request_token_usage
            .get(index - 1)
            .cloned()
            .unwrap_or_default();

        TotalTokenUsage {
            total: token_usage.total_tokens() as usize,
            max,
        }
    }

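    /// Returns the thread's total token usage relative to the default model's maximum context
    /// size, preferring the token count reported by a context-window-exceeded error when that
    /// error came from the same model.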
    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
        let model_registry = LanguageModelRegistry::read_global(cx);
        let Some(model) = model_registry.default_model() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        if let Some(exceeded_error) = &self.exceeded_window_error {
            if model.model.id() == exceeded_error.model_id {
                return TotalTokenUsage {
                    total: exceeded_error.token_count,
                    max,
                };
            }
        }

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens() as usize;

        TotalTokenUsage { total, max }
    }

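    /// Returns the token usage recorded for the most recent message, if any.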
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }

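    /// Resizes the per-message token usage history to match the current message count and
    /// records the given usage for the last message.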
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }

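    /// Records that the user denied a pending tool use by inserting an error tool result and
    /// marking the tool as finished.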
    pub fn deny_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        cx: &mut Context<Self>,
    ) {
        let err = Err(anyhow!("Permission to run tool action denied by user"));

        self.tool_use
            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
        self.tool_finished(tool_use_id, None, true, cx);
    }
}

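/// An error that can be surfaced to the user via [`ThreadEvent::ShowError`].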
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}

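/// An event emitted by a [`Thread`] as a completion streams in and tools are run.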
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    UsageUpdated(RequestUsage),
    StreamedCompletion,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}

impl EventEmitter<ThreadEvent> for Thread {}

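/// A completion request that is currently streaming, along with the task driving it.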
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain all 3 user messages in addition to the initial system message
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert new text near the start so the buffer differs from what was last read
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

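    // Registers the settings and global state that the thread tests depend on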
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

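    // Helper to build the workspace, thread store, thread, and context store used by the tests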
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

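    // Helper to open the buffer at `path` and add it to the context store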
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}