1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
42
/// What a completion request to the model is for; the request is assembled
/// differently per kind (e.g. tool uses/results are omitted when summarizing —
/// see [`Thread::to_completion_request`]).
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// An ordinary conversational turn with the model.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
49
/// Unique identifier for a [`Thread`].
///
/// Freshly created ids are v4 UUID strings (see [`ThreadId::new`]); ids can
/// also be constructed verbatim from any string via `From<&str>`.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
///
/// Backed by a freshly generated v4 UUID string (see [`PromptId::new`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
78
79impl PromptId {
80 pub fn new() -> Self {
81 Self(Uuid::new_v4().to_string().into())
82 }
83}
84
85impl std::fmt::Display for PromptId {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", self.0)
88 }
89}
90
/// Identifier for a [`Message`] within a single [`Thread`], allocated
/// sequentially via [`MessageId::post_inc`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
93
94impl MessageId {
95 fn post_inc(&mut self) -> Self {
96 Self(post_inc(&mut self.0))
97 }
98}
99
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments making up the message body.
    pub segments: Vec<MessageSegment>,
    /// Pre-formatted context (files, symbols, etc.) attached to this message;
    /// empty when the message carries no extra context.
    pub context: String,
}
108
109impl Message {
110 /// Returns whether the message contains any meaningful text that should be displayed
111 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
112 pub fn should_display_content(&self) -> bool {
113 self.segments.iter().all(|segment| segment.should_display())
114 }
115
116 pub fn push_thinking(&mut self, text: &str) {
117 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
118 segment.push_str(text);
119 } else {
120 self.segments
121 .push(MessageSegment::Thinking(text.to_string()));
122 }
123 }
124
125 pub fn push_text(&mut self, text: &str) {
126 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
127 segment.push_str(text);
128 } else {
129 self.segments.push(MessageSegment::Text(text.to_string()));
130 }
131 }
132
133 pub fn to_string(&self) -> String {
134 let mut result = String::new();
135
136 if !self.context.is_empty() {
137 result.push_str(&self.context);
138 }
139
140 for segment in &self.segments {
141 match segment {
142 MessageSegment::Text(text) => result.push_str(text),
143 MessageSegment::Thinking(text) => {
144 result.push_str("<think>");
145 result.push_str(text);
146 result.push_str("</think>");
147 }
148 }
149 }
150
151 result
152 }
153}
154
/// One contiguous piece of a [`Message`] body: either ordinary response text,
/// or "thinking" text that is wrapped in `<think>` tags when the message is
/// rendered to plain text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking(String),
}
160
161impl MessageSegment {
162 pub fn text_mut(&mut self) -> &mut String {
163 match self {
164 Self::Text(text) => text,
165 Self::Thinking(text) => text,
166 }
167 }
168
169 pub fn should_display(&self) -> bool {
170 // We add USING_TOOL_MARKER when making a request that includes tool uses
171 // without non-whitespace text around them, and this can cause the model
172 // to mimic the pattern, so we consider those segments not displayable.
173 match self {
174 Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
175 Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
176 }
177 }
178}
179
/// Point-in-time capture of the project's worktrees and unsaved buffers,
/// recorded when a thread is created and persisted alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree: its root path plus git metadata when the
/// worktree is under version control.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Git metadata captured for a worktree. Every field is optional since any of
/// them may be unavailable (e.g. no remote configured, detached HEAD).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// Associates a git checkpoint with the message that was current when it was
/// captured, so [`Thread::restore_checkpoint`] can roll back to that point.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up / thumbs-down feedback a user can leave on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// State of the most recent checkpoint restoration: still in flight, or
/// failed with an error. Cleared (set to `None`) on success.
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
222
223impl LastRestoreCheckpoint {
224 pub fn message_id(&self) -> MessageId {
225 match self {
226 LastRestoreCheckpoint::Pending { message_id } => *message_id,
227 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
228 }
229 }
230}
231
/// Lifecycle of the thread's detailed (long-form) summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// A summary request is in flight, keyed to `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// A summary was produced, keyed to the `message_id` it was generated at.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
244
/// Running token count for a thread together with the model's context-window
/// maximum, used to drive the context-usage UI states.
#[derive(Default)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable (a float such as
    /// `0.9`); a set-but-unparseable value still panics, surfacing the typo.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            // `map` + `unwrap_or` parses only when the variable is set, instead
            // of allocating a default "0.8" string just to parse it.
            .map(|value| value.parse().unwrap())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more counted against the
    /// same maximum.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Discrete severity buckets derived from [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
285
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Bumped via `touch_updated_at` on every mutation.
    updated_at: DateTime<Utc>,
    /// Short user-visible title; `None` until one is generated.
    summary: Option<SharedString>,
    /// In-flight summary-generation task, if any.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    /// Next id handed out by `insert_message`.
    next_message_id: MessageId,
    /// Refreshed for each user submission via `advance_prompt_id`.
    last_prompt_id: PromptId,
    /// All context ever attached to this thread, keyed by id.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Which context ids were attached alongside each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    /// Counter used to tag pending completions with unique ids.
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Outcome of the most recent checkpoint restore; `None` after success.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured at the latest user message; committed or discarded
    /// by `finalize_pending_checkpoint` once generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project taken when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Total token usage across all requests in this thread.
    cumulative_token_usage: TokenUsage,
    /// Set when a request overflowed the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// Timestamp of the most recent telemetry auto-capture.
    last_auto_capture_at: Option<Instant>,
    /// Optional observer invoked with each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
}
320
/// Details of a request that exceeded the model's context window; serialized
/// with the thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
328
329impl Thread {
    /// Creates an empty thread rooted in `project`.
    ///
    /// Starts with no messages and no summary, a fresh
    /// [`ThreadId`]/[`PromptId`], and kicks off a background capture of the
    /// initial project snapshot.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the snapshot eagerly so it reflects the project as it
            // was when the thread began; `Shared` lets later callers await it.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
374
    /// Reconstructs a [`Thread`] from its serialized representation.
    ///
    /// Messages, token usage, and summary state are restored; transient state
    /// (pending completions, checkpoints, per-message context links, feedback)
    /// starts fresh rather than being persisted.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue message numbering after the highest persisted id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        // Rebuild tool-use bookkeeping from the persisted messages, passing
        // `|_| true` as the filter predicate.
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text } => {
                                MessageSegment::Thinking(text)
                            }
                        })
                        .collect(),
                    context: message.context,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            // The snapshot captured when the thread was originally created.
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
444
    /// Installs an observer invoked with each completion request and the raw
    /// stream of events received back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
            + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The thread's stable unique id.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// True when the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last mutated.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as mutated now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if none has been produced yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used before a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The generated summary, or [`Self::DEFAULT_SUMMARY`] if none exists.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
486
487 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
488 let Some(current_summary) = &self.summary else {
489 // Don't allow setting summary until generated
490 return;
491 };
492
493 let mut new_summary = new_summary.into();
494
495 if new_summary.is_empty() {
496 new_summary = Self::DEFAULT_SUMMARY;
497 }
498
499 if current_summary != &new_summary {
500 self.summary = Some(new_summary);
501 cx.emit(ThreadEvent::SummaryChanged);
502 }
503 }
504
505 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
506 self.latest_detailed_summary()
507 .unwrap_or_else(|| self.text().into())
508 }
509
510 fn latest_detailed_summary(&self) -> Option<SharedString> {
511 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
512 Some(text.clone())
513 } else {
514 None
515 }
516 }
517
    /// Looks up a message by id.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    /// Iterates over all messages in insertion order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// True when any tool use has not yet completed.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded at `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
555
    /// Restores the project to `checkpoint`'s git state, truncating the thread
    /// back to the checkpoint's message on success.
    ///
    /// While the restore is in flight `last_restore_checkpoint` is `Pending`;
    /// on failure it records the error, and on success it is cleared.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Only drop messages when the git restore succeeded, so a
                    // failed restore leaves the thread intact.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any previously pending checkpoint no longer applies.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
591
    /// Resolves the checkpoint taken at the latest user message once
    /// generation has finished.
    ///
    /// Takes a fresh "final" checkpoint and compares it with the pending one:
    /// if they are equal the pending checkpoint is deleted; otherwise it is
    /// kept so the user can roll back. The temporary final checkpoint is
    /// always deleted. No-op while the thread is still generating or when
    /// there is no pending checkpoint.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // A failed comparison is treated as "changed", keeping the
                    // pending checkpoint.
                    .unwrap_or(false);

                if equal {
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // Couldn't take a final checkpoint: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
642
    /// Records `checkpoint` against its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Outcome of the most recent checkpoint restore, if pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
653
654 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
655 let Some(message_ix) = self
656 .messages
657 .iter()
658 .rposition(|message| message.id == message_id)
659 else {
660 return;
661 };
662 for deleted_message in self.messages.drain(message_ix..) {
663 self.context_by_message.remove(&deleted_message.id);
664 self.checkpoints_by_message.remove(&deleted_message.id);
665 }
666 cx.notify();
667 }
668
669 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
670 self.context_by_message
671 .get(&id)
672 .into_iter()
673 .flat_map(|context| {
674 context
675 .iter()
676 .filter_map(|context_id| self.context.get(&context_id))
677 })
678 }
679
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses issued by the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results attached to the given message.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a specific tool use, if it has completed.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The rendered card for a tool use, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Whether the given message has any tool results attached.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
713
714 /// Filter out contexts that have already been included in previous messages
715 pub fn filter_new_context<'a>(
716 &self,
717 context: impl Iterator<Item = &'a AssistantContext>,
718 ) -> impl Iterator<Item = &'a AssistantContext> {
719 context.filter(|ctx| self.is_context_new(ctx))
720 }
721
722 fn is_context_new(&self, context: &AssistantContext) -> bool {
723 !self.context.contains_key(&context.id())
724 }
725
    /// Appends a new user message, attaching only context that has not already
    /// been included earlier in the thread.
    ///
    /// Newly attached buffers are registered with the action log, and when
    /// `git_checkpoint` is provided it becomes the pending checkpoint for this
    /// message. Returns the id of the inserted message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Only keep context that hasn't been attached to an earlier message.
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            // Store the formatted context text on the message itself.
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::Excerpt(excerpt_context) => {
                            log.buffer_added_as_context(
                                excerpt_context.context_buffer.buffer.clone(),
                                cx,
                            );
                        }
                        // These context kinds carry no buffers to track.
                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
                    }
                }
            });
        }

        // Record the association between the message and its new context.
        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
801
802 pub fn insert_message(
803 &mut self,
804 role: Role,
805 segments: Vec<MessageSegment>,
806 cx: &mut Context<Self>,
807 ) -> MessageId {
808 let id = self.next_message_id.post_inc();
809 self.messages.push(Message {
810 id,
811 role,
812 segments,
813 context: String::new(),
814 });
815 self.touch_updated_at();
816 cx.emit(ThreadEvent::MessageAdded(id));
817 id
818 }
819
820 pub fn edit_message(
821 &mut self,
822 id: MessageId,
823 new_role: Role,
824 new_segments: Vec<MessageSegment>,
825 cx: &mut Context<Self>,
826 ) -> bool {
827 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
828 return false;
829 };
830 message.role = new_role;
831 message.segments = new_segments;
832 self.touch_updated_at();
833 cx.emit(ThreadEvent::MessageEdited(id));
834 true
835 }
836
837 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
838 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
839 return false;
840 };
841 self.messages.remove(index);
842 self.context_by_message.remove(&id);
843 self.touch_updated_at();
844 cx.emit(ThreadEvent::MessageDeleted(id));
845 true
846 }
847
848 /// Returns the representation of this [`Thread`] in a textual form.
849 ///
850 /// This is the representation we use when attaching a thread as context to another thread.
851 pub fn text(&self) -> String {
852 let mut text = String::new();
853
854 for message in &self.messages {
855 text.push_str(match message.role {
856 language_model::Role::User => "User:",
857 language_model::Role::Assistant => "Assistant:",
858 language_model::Role::System => "System:",
859 });
860 text.push('\n');
861
862 for segment in &message.segments {
863 match segment {
864 MessageSegment::Text(content) => text.push_str(content),
865 MessageSegment::Thinking(content) => {
866 text.push_str(&format!("<think>{}</think>", content))
867 }
868 }
869 }
870 text.push('\n');
871 }
872
873 text
874 }
875
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot to resolve, then captures the
    /// messages (including tool uses and results), token usage, and summary
    /// state.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking(text) => {
                                    SerializedMessageSegment::Thinking { text: text.clone() }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
931
932 pub fn send_to_model(
933 &mut self,
934 model: Arc<dyn LanguageModel>,
935 request_kind: RequestKind,
936 cx: &mut Context<Self>,
937 ) {
938 let mut request = self.to_completion_request(request_kind, cx);
939 if model.supports_tools() {
940 request.tools = {
941 let mut tools = Vec::new();
942 tools.extend(
943 self.tools()
944 .read(cx)
945 .enabled_tools(cx)
946 .into_iter()
947 .filter_map(|tool| {
948 // Skip tools that cannot be supported
949 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
950 Some(LanguageModelRequestTool {
951 name: tool.name(),
952 description: tool.description(),
953 input_schema,
954 })
955 }),
956 );
957
958 tools
959 };
960 }
961
962 self.stream_completion(request, model, cx);
963 }
964
965 pub fn used_tools_since_last_user_message(&self) -> bool {
966 for message in self.messages.iter().rev() {
967 if self.tool_use.message_has_tool_results(message.id) {
968 return true;
969 } else if message.role == Role::User {
970 return false;
971 }
972 }
973
974 false
975 }
976
    /// Builds the [`LanguageModelRequest`] for this thread's current state.
    ///
    /// Starts with a cached system prompt (when the project context is ready),
    /// then adds one request message per thread message. For [`RequestKind::Chat`]
    /// tool uses and results are attached; for [`RequestKind::Summarize`]
    /// messages carrying tool results are skipped entirely. Finally a note
    /// about stale tracked files is appended.
    pub fn to_completion_request(
        &self,
        request_kind: RequestKind,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // The project context is filled in asynchronously; sending a
            // request before it resolves is unexpected.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_results(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                    if self.tool_use.message_has_tool_results(message.id) {
                        continue;
                    }
                }
            }

            if !message.segments.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.to_string()));
            }

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            };

            request.messages.push(request_message);
        }

        // Mark the final message as cacheable.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1069
1070 fn attached_tracked_files_state(
1071 &self,
1072 messages: &mut Vec<LanguageModelRequestMessage>,
1073 cx: &App,
1074 ) {
1075 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1076
1077 let mut stale_message = String::new();
1078
1079 let action_log = self.action_log.read(cx);
1080
1081 for stale_file in action_log.stale_buffers(cx) {
1082 let Some(file) = stale_file.read(cx).file() else {
1083 continue;
1084 };
1085
1086 if stale_message.is_empty() {
1087 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1088 }
1089
1090 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1091 }
1092
1093 let mut content = Vec::with_capacity(2);
1094
1095 if !stale_message.is_empty() {
1096 content.push(stale_message.into());
1097 }
1098
1099 if !content.is_empty() {
1100 let context_message = LanguageModelRequestMessage {
1101 role: Role::User,
1102 content,
1103 cache: false,
1104 };
1105
1106 messages.push(context_message);
1107 }
1108 }
1109
    /// Spawns a task that streams a completion for `request` from `model`,
    /// applying streamed events to the thread as they arrive and emitting
    /// `ThreadEvent`s for the UI.
    ///
    /// The task is tracked in `self.pending_completions` until it finishes or
    /// is canceled; on `StopReason::ToolUse` it kicks off the pending tools.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request/response pair when a callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot cumulative usage before streaming so the telemetry at the
            // end can report the delta for this completion alone.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                while let Some(event) = events.next().await {
                    // Mirror every raw event (including errors, stringified)
                    // into the request-callback log.
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates appear to be cumulative per
                                // request, so only the delta since the previous
                                // update is added to the thread-wide total.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent
                                // assistant message, creating one if needed.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Yield between events so other tasks get a chance to run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // Generate a thread title once there is a real exchange.
                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to specific UI errors;
                            // everything else becomes a generic message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Remember the overflow so token-usage
                                        // queries can report it for this model.
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report per-completion token usage as the delta from the
                    // snapshot taken before streaming started.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1345
    /// Generates a short title ("summary") for the thread using the configured
    /// thread-summary model.
    ///
    /// No-op when no summary model is configured or its provider is not
    /// authenticated. The result is stored in `self.summary` and a
    /// `SummaryGenerated` event is emitted.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
                 If the conversation is about a specific subject, include it in the title. \
                 Be descriptive. DO NOT speak in the first person."
                    .into(),
            ],
            cache: false,
        });

        // Replacing `pending_summary` drops any previous summarization task.
        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    // Only keep the first line of output.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // Keep the previous summary if the model produced nothing.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1406
    /// Starts generating a detailed Markdown summary of the thread.
    ///
    /// Returns `None` when there are no messages, when a summary for the
    /// current last message is already generated or in progress, or when no
    /// authenticated summary model is available. Otherwise returns the task
    /// driving the generation and marks the state as `Generating`.
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a detailed summary of this conversation. Include:\n\
                1. A brief overview of what was discussed\n\
                2. Key facts or information discovered\n\
                3. Outcomes or conclusions reached\n\
                4. Any action items or next steps if any\n\
                Format it in Markdown with headings and bullet points."
                    .into(),
            ],
            cache: false,
        });

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, reset the state so a later call
            // can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the whole streamed response; chunk errors are logged
            // and skipped rather than aborting the summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1480
1481 pub fn is_generating_detailed_summary(&self) -> bool {
1482 matches!(
1483 self.detailed_summary_state,
1484 DetailedSummaryState::Generating { .. }
1485 )
1486 }
1487
1488 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1489 self.auto_capture_telemetry(cx);
1490 let request = self.to_completion_request(RequestKind::Chat, cx);
1491 let messages = Arc::new(request.messages);
1492 let pending_tool_uses = self
1493 .tool_use
1494 .pending_tool_uses()
1495 .into_iter()
1496 .filter(|tool_use| tool_use.status.is_idle())
1497 .cloned()
1498 .collect::<Vec<_>>();
1499
1500 for tool_use in pending_tool_uses.iter() {
1501 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1502 if tool.needs_confirmation(&tool_use.input, cx)
1503 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1504 {
1505 self.tool_use.confirm_tool_use(
1506 tool_use.id.clone(),
1507 tool_use.ui_text.clone(),
1508 tool_use.input.clone(),
1509 messages.clone(),
1510 tool,
1511 );
1512 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1513 } else {
1514 self.run_tool(
1515 tool_use.id.clone(),
1516 tool_use.ui_text.clone(),
1517 tool_use.input.clone(),
1518 &messages,
1519 tool,
1520 cx,
1521 );
1522 }
1523 }
1524 }
1525
1526 pending_tool_uses
1527 }
1528
1529 pub fn run_tool(
1530 &mut self,
1531 tool_use_id: LanguageModelToolUseId,
1532 ui_text: impl Into<SharedString>,
1533 input: serde_json::Value,
1534 messages: &[LanguageModelRequestMessage],
1535 tool: Arc<dyn Tool>,
1536 cx: &mut Context<Thread>,
1537 ) {
1538 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1539 self.tool_use
1540 .run_pending_tool(tool_use_id, ui_text.into(), task);
1541 }
1542
1543 fn spawn_tool_use(
1544 &mut self,
1545 tool_use_id: LanguageModelToolUseId,
1546 messages: &[LanguageModelRequestMessage],
1547 input: serde_json::Value,
1548 tool: Arc<dyn Tool>,
1549 cx: &mut Context<Thread>,
1550 ) -> Task<()> {
1551 let tool_name: Arc<str> = tool.name().into();
1552
1553 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1554 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1555 } else {
1556 tool.run(
1557 input,
1558 messages,
1559 self.project.clone(),
1560 self.action_log.clone(),
1561 cx,
1562 )
1563 };
1564
1565 // Store the card separately if it exists
1566 if let Some(card) = tool_result.card.clone() {
1567 self.tool_use
1568 .insert_tool_result_card(tool_use_id.clone(), card);
1569 }
1570
1571 cx.spawn({
1572 async move |thread: WeakEntity<Thread>, cx| {
1573 let output = tool_result.output.await;
1574
1575 thread
1576 .update(cx, |thread, cx| {
1577 let pending_tool_use = thread.tool_use.insert_tool_output(
1578 tool_use_id.clone(),
1579 tool_name,
1580 output,
1581 cx,
1582 );
1583 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1584 })
1585 .ok();
1586 }
1587 })
1588 }
1589
1590 fn tool_finished(
1591 &mut self,
1592 tool_use_id: LanguageModelToolUseId,
1593 pending_tool_use: Option<PendingToolUse>,
1594 canceled: bool,
1595 cx: &mut Context<Self>,
1596 ) {
1597 if self.all_tools_finished() {
1598 let model_registry = LanguageModelRegistry::read_global(cx);
1599 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1600 self.attach_tool_results(cx);
1601 if !canceled {
1602 self.send_to_model(model, RequestKind::Chat, cx);
1603 }
1604 }
1605 }
1606
1607 cx.emit(ThreadEvent::ToolFinished {
1608 tool_use_id,
1609 pending_tool_use,
1610 });
1611 }
1612
    /// Inserts an empty user message that will be populated with the tool
    /// results when the request is sent to the model.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // TODO: Don't insert a dummy user message here. Ensure this works with the thinking model.
        // Insert a user message to contain the tool results.
        self.insert_user_message("Here are the tool results.", Vec::new(), None, cx);
    }
1619
1620 /// Cancels the last pending completion, if there are any pending.
1621 ///
1622 /// Returns whether a completion was canceled.
1623 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1624 let canceled = if self.pending_completions.pop().is_some() {
1625 true
1626 } else {
1627 let mut canceled = false;
1628 for pending_tool_use in self.tool_use.cancel_pending() {
1629 canceled = true;
1630 self.tool_finished(
1631 pending_tool_use.id.clone(),
1632 Some(pending_tool_use),
1633 true,
1634 cx,
1635 );
1636 }
1637 canceled
1638 };
1639 self.finalize_pending_checkpoint(cx);
1640 canceled
1641 }
1642
    /// Returns the thread-level feedback, if the user has given any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1646
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1650
    /// Records `feedback` for a single message and reports it — along with a
    /// serialized snapshot of the thread and project — via telemetry.
    ///
    /// Re-submitting the same feedback for the same message is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure shouldn't block the report; send null.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events();

            Ok(())
        })
    }
1708
    /// Reports thread-level feedback.
    ///
    /// When the thread has an assistant message, the feedback is attributed to
    /// the most recent one via [`Self::report_message_feedback`]; otherwise it
    /// is stored as thread-level feedback and reported without message fields.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // Capture everything that needs `cx` before moving to the background.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // A serialization failure shouldn't block the report; send null.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events();

                Ok(())
            })
        }
    }
1754
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are collected concurrently; dirty buffer paths are
    /// gathered from the buffer store afterwards.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record the paths of buffers with unsaved changes; buffers
            // without a backing file are skipped.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
1792
1793 fn worktree_snapshot(
1794 worktree: Entity<project::Worktree>,
1795 git_store: Entity<GitStore>,
1796 cx: &App,
1797 ) -> Task<WorktreeSnapshot> {
1798 cx.spawn(async move |cx| {
1799 // Get worktree path and snapshot
1800 let worktree_info = cx.update(|app_cx| {
1801 let worktree = worktree.read(app_cx);
1802 let path = worktree.abs_path().to_string_lossy().to_string();
1803 let snapshot = worktree.snapshot();
1804 (path, snapshot)
1805 });
1806
1807 let Ok((worktree_path, _snapshot)) = worktree_info else {
1808 return WorktreeSnapshot {
1809 worktree_path: String::new(),
1810 git_state: None,
1811 };
1812 };
1813
1814 let git_state = git_store
1815 .update(cx, |git_store, cx| {
1816 git_store
1817 .repositories()
1818 .values()
1819 .find(|repo| {
1820 repo.read(cx)
1821 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1822 .is_some()
1823 })
1824 .cloned()
1825 })
1826 .ok()
1827 .flatten()
1828 .map(|repo| {
1829 repo.update(cx, |repo, _| {
1830 let current_branch =
1831 repo.branch.as_ref().map(|branch| branch.name.to_string());
1832 repo.send_job(None, |state, _| async move {
1833 let RepositoryState::Local { backend, .. } = state else {
1834 return GitState {
1835 remote_url: None,
1836 head_sha: None,
1837 current_branch,
1838 diff: None,
1839 };
1840 };
1841
1842 let remote_url = backend.remote_url("origin");
1843 let head_sha = backend.head_sha();
1844 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1845
1846 GitState {
1847 remote_url,
1848 head_sha,
1849 current_branch,
1850 diff,
1851 }
1852 })
1853 })
1854 });
1855
1856 let git_state = match git_state {
1857 Some(git_state) => match git_state.ok() {
1858 Some(git_state) => git_state.await.ok(),
1859 None => None,
1860 },
1861 None => None,
1862 };
1863
1864 WorktreeSnapshot {
1865 worktree_path,
1866 git_state,
1867 }
1868 })
1869 }
1870
1871 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1872 let mut markdown = Vec::new();
1873
1874 if let Some(summary) = self.summary() {
1875 writeln!(markdown, "# {summary}\n")?;
1876 };
1877
1878 for message in self.messages() {
1879 writeln!(
1880 markdown,
1881 "## {role}\n",
1882 role = match message.role {
1883 Role::User => "User",
1884 Role::Assistant => "Assistant",
1885 Role::System => "System",
1886 }
1887 )?;
1888
1889 if !message.context.is_empty() {
1890 writeln!(markdown, "{}", message.context)?;
1891 }
1892
1893 for segment in &message.segments {
1894 match segment {
1895 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1896 MessageSegment::Thinking(text) => {
1897 writeln!(markdown, "<think>{}</think>\n", text)?
1898 }
1899 }
1900 }
1901
1902 for tool_use in self.tool_uses_for_message(message.id, cx) {
1903 writeln!(
1904 markdown,
1905 "**Use Tool: {} ({})**",
1906 tool_use.name, tool_use.id
1907 )?;
1908 writeln!(markdown, "```json")?;
1909 writeln!(
1910 markdown,
1911 "{}",
1912 serde_json::to_string_pretty(&tool_use.input)?
1913 )?;
1914 writeln!(markdown, "```")?;
1915 }
1916
1917 for tool_result in self.tool_results_for_message(message.id) {
1918 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1919 if tool_result.is_error {
1920 write!(markdown, " (Error)")?;
1921 }
1922
1923 writeln!(markdown, "**\n")?;
1924 writeln!(markdown, "{}", tool_result.content)?;
1925 }
1926 }
1927
1928 Ok(String::from_utf8_lossy(&markdown).to_string())
1929 }
1930
1931 pub fn keep_edits_in_range(
1932 &mut self,
1933 buffer: Entity<language::Buffer>,
1934 buffer_range: Range<language::Anchor>,
1935 cx: &mut Context<Self>,
1936 ) {
1937 self.action_log.update(cx, |action_log, cx| {
1938 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1939 });
1940 }
1941
1942 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1943 self.action_log
1944 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1945 }
1946
1947 pub fn reject_edits_in_ranges(
1948 &mut self,
1949 buffer: Entity<language::Buffer>,
1950 buffer_ranges: Vec<Range<language::Anchor>>,
1951 cx: &mut Context<Self>,
1952 ) -> Task<Result<()>> {
1953 self.action_log.update(cx, |action_log, cx| {
1954 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1955 })
1956 }
1957
    /// The action log tracking buffer edits made by tools in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1961
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1965
    /// Automatically captures the serialized thread to telemetry.
    ///
    /// Only runs for users with the `ThreadAutoCapture` feature flag, and at
    /// most once every 10 seconds.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
            return;
        }

        // Throttle: skip if we captured within the last 10 seconds.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client().clone();
        let serialize_task = self.serialize(cx);

        // Serialization and reporting happen off the main thread.
        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events();
                    }
                }
            })
            .detach();
    }
2009
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2013
2014 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2015 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2016 return TotalTokenUsage::default();
2017 };
2018
2019 let max = model.model.max_token_count();
2020
2021 let index = self
2022 .messages
2023 .iter()
2024 .position(|msg| msg.id == message_id)
2025 .unwrap_or(0);
2026
2027 if index == 0 {
2028 return TotalTokenUsage { total: 0, max };
2029 }
2030
2031 let token_usage = &self
2032 .request_token_usage
2033 .get(index - 1)
2034 .cloned()
2035 .unwrap_or_default();
2036
2037 TotalTokenUsage {
2038 total: token_usage.total_tokens() as usize,
2039 max,
2040 }
2041 }
2042
2043 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2044 let model_registry = LanguageModelRegistry::read_global(cx);
2045 let Some(model) = model_registry.default_model() else {
2046 return TotalTokenUsage::default();
2047 };
2048
2049 let max = model.model.max_token_count();
2050
2051 if let Some(exceeded_error) = &self.exceeded_window_error {
2052 if model.model.id() == exceeded_error.model_id {
2053 return TotalTokenUsage {
2054 total: exceeded_error.token_count,
2055 max,
2056 };
2057 }
2058 }
2059
2060 let total = self
2061 .token_usage_at_last_message()
2062 .unwrap_or_default()
2063 .total_tokens() as usize;
2064
2065 TotalTokenUsage { total, max }
2066 }
2067
2068 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2069 self.request_token_usage
2070 .get(self.messages.len().saturating_sub(1))
2071 .or_else(|| self.request_token_usage.last())
2072 .cloned()
2073 }
2074
2075 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2076 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2077 self.request_token_usage
2078 .resize(self.messages.len(), placeholder);
2079
2080 if let Some(last) = self.request_token_usage.last_mut() {
2081 *last = token_usage;
2082 }
2083 }
2084
2085 pub fn deny_tool_use(
2086 &mut self,
2087 tool_use_id: LanguageModelToolUseId,
2088 tool_name: Arc<str>,
2089 cx: &mut Context<Self>,
2090 ) {
2091 let err = Err(anyhow::anyhow!(
2092 "Permission to run tool action denied by user"
2093 ));
2094
2095 self.tool_use
2096 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2097 self.tool_finished(tool_use_id.clone(), None, true, cx);
2098 }
2099}
2100
/// User-visible errors surfaced while interacting with a language model.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further requests can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model-request limit has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, carrying a header and message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2115
/// Events emitted by a [`Thread`] as completions stream and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be displayed to the user.
    ShowError(ThreadError),
    /// Request usage information was received from the provider.
    UsageUpdated(RequestUsage),
    /// A streamed event was applied to the thread.
    StreamedCompletion,
    /// A text chunk was appended to the given assistant message.
    StreamedAssistantText(MessageId, String),
    /// A thinking chunk was appended to the given assistant message.
    StreamedAssistantThinking(MessageId, String),
    /// The completion finished, successfully or not.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A short thread title was generated.
    SummaryGenerated,
    SummaryChanged,
    /// The model stopped to use tools; these are the pending uses.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool produced its result (or was canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
}
2141
// `Thread` emits `ThreadEvent`s so observers can react to streaming updates.
impl EventEmitter<ThreadEvent> for Thread {}
2143
/// An in-flight completion: `id` identifies it for removal, and `_task` keeps
/// the streaming task alive for as long as the completion is pending.
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}
2148
2149#[cfg(test)]
2150mod tests {
2151 use super::*;
2152 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2153 use assistant_settings::AssistantSettings;
2154 use context_server::ContextServerSettings;
2155 use editor::EditorSettings;
2156 use gpui::TestAppContext;
2157 use project::{FakeFs, Project};
2158 use prompt_store::PromptBuilder;
2159 use serde_json::json;
2160 use settings::{Settings, SettingsStore};
2161 use std::sync::Arc;
2162 use theme::ThemeSettings;
2163 use util::path;
2164 use workspace::Workspace;
2165
    // Verifies that file context attached to a user message is rendered into
    // the message and included in the completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Message 0 is the system/context message; message 1 is the user's.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2233
2234 #[gpui::test]
2235 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2236 init_test_settings(cx);
2237
2238 let project = create_test_project(
2239 cx,
2240 json!({
2241 "file1.rs": "fn function1() {}\n",
2242 "file2.rs": "fn function2() {}\n",
2243 "file3.rs": "fn function3() {}\n",
2244 }),
2245 )
2246 .await;
2247
2248 let (_, _thread_store, thread, context_store) =
2249 setup_test_environment(cx, project.clone()).await;
2250
2251 // Open files individually
2252 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2253 .await
2254 .unwrap();
2255 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2256 .await
2257 .unwrap();
2258 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2259 .await
2260 .unwrap();
2261
2262 // Get the context objects
2263 let contexts = context_store.update(cx, |store, _| store.context().clone());
2264 assert_eq!(contexts.len(), 3);
2265
2266 // First message with context 1
2267 let message1_id = thread.update(cx, |thread, cx| {
2268 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2269 });
2270
2271 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2272 let message2_id = thread.update(cx, |thread, cx| {
2273 thread.insert_user_message(
2274 "Message 2",
2275 vec![contexts[0].clone(), contexts[1].clone()],
2276 None,
2277 cx,
2278 )
2279 });
2280
2281 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2282 let message3_id = thread.update(cx, |thread, cx| {
2283 thread.insert_user_message(
2284 "Message 3",
2285 vec![
2286 contexts[0].clone(),
2287 contexts[1].clone(),
2288 contexts[2].clone(),
2289 ],
2290 None,
2291 cx,
2292 )
2293 });
2294
2295 // Check what contexts are included in each message
2296 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2297 (
2298 thread.message(message1_id).unwrap().clone(),
2299 thread.message(message2_id).unwrap().clone(),
2300 thread.message(message3_id).unwrap().clone(),
2301 )
2302 });
2303
2304 // First message should include context 1
2305 assert!(message1.context.contains("file1.rs"));
2306
2307 // Second message should include only context 2 (not 1)
2308 assert!(!message2.context.contains("file1.rs"));
2309 assert!(message2.context.contains("file2.rs"));
2310
2311 // Third message should include only context 3 (not 1 or 2)
2312 assert!(!message3.context.contains("file1.rs"));
2313 assert!(!message3.context.contains("file2.rs"));
2314 assert!(message3.context.contains("file3.rs"));
2315
2316 // Check entire request to make sure all contexts are properly included
2317 let request = thread.update(cx, |thread, cx| {
2318 thread.to_completion_request(RequestKind::Chat, cx)
2319 });
2320
2321 // The request should contain all 3 messages
2322 assert_eq!(request.messages.len(), 4);
2323
2324 // Check that the contexts are properly formatted in each message
2325 assert!(request.messages[1].string_contents().contains("file1.rs"));
2326 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2327 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2328
2329 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2330 assert!(request.messages[2].string_contents().contains("file2.rs"));
2331 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2332
2333 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2334 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2335 assert!(request.messages[3].string_contents().contains("file3.rs"));
2336 }
2337
2338 #[gpui::test]
2339 async fn test_message_without_files(cx: &mut TestAppContext) {
2340 init_test_settings(cx);
2341
2342 let project = create_test_project(
2343 cx,
2344 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2345 )
2346 .await;
2347
2348 let (_, _thread_store, thread, _context_store) =
2349 setup_test_environment(cx, project.clone()).await;
2350
2351 // Insert user message without any context (empty context vector)
2352 let message_id = thread.update(cx, |thread, cx| {
2353 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2354 });
2355
2356 // Check content and context in message object
2357 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2358
2359 // Context should be empty when no files are included
2360 assert_eq!(message.role, Role::User);
2361 assert_eq!(message.segments.len(), 1);
2362 assert_eq!(
2363 message.segments[0],
2364 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2365 );
2366 assert_eq!(message.context, "");
2367
2368 // Check message in request
2369 let request = thread.update(cx, |thread, cx| {
2370 thread.to_completion_request(RequestKind::Chat, cx)
2371 });
2372
2373 assert_eq!(request.messages.len(), 2);
2374 assert_eq!(
2375 request.messages[1].string_contents(),
2376 "What is the best way to learn Rust?"
2377 );
2378
2379 // Add second message, also without context
2380 let message2_id = thread.update(cx, |thread, cx| {
2381 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2382 });
2383
2384 let message2 =
2385 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2386 assert_eq!(message2.context, "");
2387
2388 // Check that both messages appear in the request
2389 let request = thread.update(cx, |thread, cx| {
2390 thread.to_completion_request(RequestKind::Chat, cx)
2391 });
2392
2393 assert_eq!(request.messages.len(), 3);
2394 assert_eq!(
2395 request.messages[1].string_contents(),
2396 "What is the best way to learn Rust?"
2397 );
2398 assert_eq!(
2399 request.messages[2].string_contents(),
2400 "Are there any good books?"
2401 );
2402 }
2403
2404 #[gpui::test]
2405 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2406 init_test_settings(cx);
2407
2408 let project = create_test_project(
2409 cx,
2410 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2411 )
2412 .await;
2413
2414 let (_workspace, _thread_store, thread, context_store) =
2415 setup_test_environment(cx, project.clone()).await;
2416
2417 // Open buffer and add it to context
2418 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2419 .await
2420 .unwrap();
2421
2422 let context =
2423 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2424
2425 // Insert user message with the buffer as context
2426 thread.update(cx, |thread, cx| {
2427 thread.insert_user_message("Explain this code", vec![context], None, cx)
2428 });
2429
2430 // Create a request and check that it doesn't have a stale buffer warning yet
2431 let initial_request = thread.update(cx, |thread, cx| {
2432 thread.to_completion_request(RequestKind::Chat, cx)
2433 });
2434
2435 // Make sure we don't have a stale file warning yet
2436 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2437 msg.string_contents()
2438 .contains("These files changed since last read:")
2439 });
2440 assert!(
2441 !has_stale_warning,
2442 "Should not have stale buffer warning before buffer is modified"
2443 );
2444
2445 // Modify the buffer
2446 buffer.update(cx, |buffer, cx| {
2447 // Find a position at the end of line 1
2448 buffer.edit(
2449 [(1..1, "\n println!(\"Added a new line\");\n")],
2450 None,
2451 cx,
2452 );
2453 });
2454
2455 // Insert another user message without context
2456 thread.update(cx, |thread, cx| {
2457 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2458 });
2459
2460 // Create a new request and check for the stale buffer warning
2461 let new_request = thread.update(cx, |thread, cx| {
2462 thread.to_completion_request(RequestKind::Chat, cx)
2463 });
2464
2465 // We should have a stale file warning as the last message
2466 let last_message = new_request
2467 .messages
2468 .last()
2469 .expect("Request should have messages");
2470
2471 // The last message should be the stale buffer notification
2472 assert_eq!(last_message.role, Role::User);
2473
2474 // Check the exact content of the message
2475 let expected_content = "These files changed since last read:\n- code.rs\n";
2476 assert_eq!(
2477 last_message.string_contents(),
2478 expected_content,
2479 "Last message should be exactly the stale buffer notification"
2480 );
2481 }
2482
    /// Installs the global test `SettingsStore` and registers every settings
    /// type the thread tests rely on. Call once at the top of each test.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store global is set first; the init/register calls
            // below are assumed to depend on it being present — keep it first.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }
2498
2499 // Helper to create a test project with test files
2500 async fn create_test_project(
2501 cx: &mut TestAppContext,
2502 files: serde_json::Value,
2503 ) -> Entity<Project> {
2504 let fs = FakeFs::new(cx.executor());
2505 fs.insert_tree(path!("/test"), files).await;
2506 Project::test(fs, [path!("/test").as_ref()], cx).await
2507 }
2508
2509 async fn setup_test_environment(
2510 cx: &mut TestAppContext,
2511 project: Entity<Project>,
2512 ) -> (
2513 Entity<Workspace>,
2514 Entity<ThreadStore>,
2515 Entity<Thread>,
2516 Entity<ContextStore>,
2517 ) {
2518 let (workspace, cx) =
2519 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2520
2521 let thread_store = cx
2522 .update(|_, cx| {
2523 ThreadStore::load(
2524 project.clone(),
2525 cx.new(|_| ToolWorkingSet::default()),
2526 Arc::new(PromptBuilder::new(None).unwrap()),
2527 cx,
2528 )
2529 })
2530 .await
2531 .unwrap();
2532
2533 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2534 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2535
2536 (workspace, thread_store, thread, context_store)
2537 }
2538
2539 async fn add_file_to_context(
2540 project: &Entity<Project>,
2541 context_store: &Entity<ContextStore>,
2542 path: &str,
2543 cx: &mut TestAppContext,
2544 ) -> Result<Entity<language::Buffer>> {
2545 let buffer_path = project
2546 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2547 .unwrap();
2548
2549 let buffer = project
2550 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2551 .await
2552 .unwrap();
2553
2554 context_store
2555 .update(cx, |store, cx| {
2556 store.add_file_from_buffer(buffer.clone(), cx)
2557 })
2558 .await?;
2559
2560 Ok(buffer)
2561 }
2562}