1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
18 WeakEntity,
19};
20use language_model::{
21 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
22 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
23 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
24 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
25 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
26 StopReason, TokenUsage,
27};
28use postage::stream::Stream as _;
29use project::Project;
30use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
31use prompt_store::{ModelContext, PromptBuilder};
32use proto::Plan;
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use settings::Settings;
36use thiserror::Error;
37use util::{ResultExt as _, TryFutureExt as _, post_inc};
38use uuid::Uuid;
39use zed_llm_client::CompletionMode;
40
41use crate::ThreadStore;
42use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
43use crate::thread_store::{
44 SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
45 SerializedToolResult, SerializedToolUse, SharedProjectContext,
46};
47use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
48
/// Stable, globally unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
53
54impl ThreadId {
55 pub fn new() -> Self {
56 Self(Uuid::new_v4().to_string().into())
57 }
58}
59
60impl std::fmt::Display for ThreadId {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{}", self.0)
63 }
64}
65
66impl From<&str> for ThreadId {
67 fn from(value: &str) -> Self {
68 Self(value.into())
69 }
70}
71
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Backed by a random UUID string (see [`PromptId::new`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
77
78impl PromptId {
79 pub fn new() -> Self {
80 Self(Uuid::new_v4().to_string().into())
81 }
82}
83
84impl std::fmt::Display for PromptId {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(f, "{}", self.0)
87 }
88}
89
/// Identifies a message within a thread; ids are assigned sequentially starting at 0.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
98
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Thread-local identifier, assigned in insertion order.
    pub id: MessageId,
    /// Who authored the message (user, assistant, or system).
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context attached when the message was sent (structured items, text, images).
    pub loaded_context: LoadedContext,
}
107
108impl Message {
109 /// Returns whether the message contains any meaningful text that should be displayed
110 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
111 pub fn should_display_content(&self) -> bool {
112 self.segments.iter().all(|segment| segment.should_display())
113 }
114
115 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
116 if let Some(MessageSegment::Thinking {
117 text: segment,
118 signature: current_signature,
119 }) = self.segments.last_mut()
120 {
121 if let Some(signature) = signature {
122 *current_signature = Some(signature);
123 }
124 segment.push_str(text);
125 } else {
126 self.segments.push(MessageSegment::Thinking {
127 text: text.to_string(),
128 signature,
129 });
130 }
131 }
132
133 pub fn push_text(&mut self, text: &str) {
134 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Text(text.to_string()));
138 }
139 }
140
141 pub fn to_string(&self) -> String {
142 let mut result = String::new();
143
144 if !self.loaded_context.text.is_empty() {
145 result.push_str(&self.loaded_context.text);
146 }
147
148 for segment in &self.segments {
149 match segment {
150 MessageSegment::Text(text) => result.push_str(text),
151 MessageSegment::Thinking { text, .. } => {
152 result.push_str("<think>\n");
153 result.push_str(text);
154 result.push_str("\n</think>");
155 }
156 MessageSegment::RedactedThinking(_) => {}
157 }
158 }
159
160 result
161 }
162}
163
/// One piece of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain assistant/user text.
    Text(String),
    /// Model "thinking" output, optionally accompanied by a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking; opaque bytes, never shown to the user.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Whether this segment carries user-visible content.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is never displayable. (Previously the emptiness
    /// check was inverted, reporting empty segments as displayable.)
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
183
/// Snapshot of the project's worktrees, captured when a thread is created and
/// stored alongside the serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Buffer paths recorded as unsaved at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree snapshot: its path plus git state, when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
/// All fields are optional since any of them may be unavailable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
204
/// Associates a message with a git checkpoint so the project can be restored
/// to its state as of that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// Status of the most recent checkpoint restore attempt.
pub enum LastRestoreCheckpoint {
    /// The restore is still running.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed with the given error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
226
227impl LastRestoreCheckpoint {
228 pub fn message_id(&self) -> MessageId {
229 match self {
230 LastRestoreCheckpoint::Pending { message_id } => *message_id,
231 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
232 }
233 }
234}
235
/// Progress of detailed-summary generation for a thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// Generation is in flight, anchored to `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// A summary has been produced, anchored to `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
248
249impl DetailedSummaryState {
250 fn text(&self) -> Option<SharedString> {
251 if let Self::Generated { text, .. } = self {
252 Some(text.clone())
253 } else {
254 None
255 }
256 }
257}
258
/// Running token count for a thread measured against the model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size of the selected model (0 when unknown).
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, nearing the limit, or exceeded.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be tuned via
        // ZED_THREAD_WARNING_THRESHOLD. Fall back to the default instead of
        // panicking when the variable is unset or malformed.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Bucketed severity of a thread's token usage.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
303
304fn default_completion_mode(cx: &App) -> CompletionMode {
305 if cx.is_staff() {
306 CompletionMode::Max
307 } else {
308 CompletionMode::Normal
309 }
310}
311
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Human-readable title; `None` until a summary has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: CompletionMode,
    // Kept sorted by `MessageId`; ids come from `next_message_id`.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken at the latest user message; promoted or discarded by
    // `finalize_pending_checkpoint` once generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Snapshot captured at thread creation; awaited during serialization.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Optional hook observing each completion request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
349
/// Recorded when a completion fails because the conversation no longer fits
/// in the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
357
358impl Thread {
    /// Creates an empty thread backed by `project`, capturing an initial
    /// project snapshot in the background.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        // New threads start on the globally configured default model.
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Snapshot the project asynchronously; shared so serialization
                // can await the same task later.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
409
    /// Rebuilds a thread from its serialized form.
    ///
    /// Restored messages only retain their context text; structured context
    /// items and images are not persisted.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest restored message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the global
        // default if it can't be selected.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
503
    /// Installs a hook that observes each completion request together with the
    /// raw events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The stable identifier for this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to the global default
    /// if none has been set on this thread yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
555
    /// Summary shown for threads that haven't generated one yet.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    /// Replaces the summary, substituting the default for empty input.
    /// No-op until an initial summary has been generated.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        // Only emit a change event when the summary actually differs.
        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
587
    /// Looks up a message by id.
    ///
    /// Ids are assigned in increasing order, so `messages` is sorted by id and
    /// a binary search suffices.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses waiting on user confirmation before they run.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
630
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Surface the in-flight restore immediately so the UI can reflect it.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failure around so it can be displayed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way the pending checkpoint is no longer applicable.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
666
    /// Promotes the pending checkpoint to a stored one once generation has
    /// finished, but only if the repository state actually changed since the
    /// checkpoint was taken.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint when something changed.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If we can't take a final checkpoint to compare against,
            // conservatively keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
705
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes `message_id` and everything after it, dropping checkpoints
    /// attached to the removed messages.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// The context items attached to the message with the given id, if any.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
738
739 pub fn is_turn_end(&self, ix: usize) -> bool {
740 if self.messages.is_empty() {
741 return false;
742 }
743
744 if !self.is_generating() && ix == self.messages.len() - 1 {
745 return true;
746 }
747
748 let Some(message) = self.messages.get(ix) else {
749 return false;
750 };
751
752 if message.role != Role::Assistant {
753 return false;
754 }
755
756 self.messages
757 .get(ix + 1)
758 .and_then(|message| {
759 self.message(message.id)
760 .map(|next_message| next_message.role == Role::User)
761 })
762 .unwrap_or(false)
763 }
764
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output of the given tool use, if a result has been recorded.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
797
798 /// Return tools that are both enabled and supported by the model
799 pub fn available_tools(
800 &self,
801 cx: &App,
802 model: Arc<dyn LanguageModel>,
803 ) -> Vec<LanguageModelRequestTool> {
804 if model.supports_tools() {
805 self.tools()
806 .read(cx)
807 .enabled_tools(cx)
808 .into_iter()
809 .filter_map(|tool| {
810 // Skip tools that cannot be supported
811 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
812 Some(LanguageModelRequestTool {
813 name: tool.name(),
814 description: tool.description(),
815 input_schema,
816 })
817 })
818 .collect()
819 } else {
820 Vec::default()
821 }
822 }
823
    /// Appends a user message with its loaded context, optionally recording a
    /// git checkpoint to associate with the message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Track buffers referenced by the context so subsequent edits to them
        // are recorded in the action log.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            cx,
        );

        // The checkpoint stays pending until generation finishes.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
857
    /// Appends an assistant message consisting of the given segments, with no
    /// attached context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
    }
865
866 pub fn insert_message(
867 &mut self,
868 role: Role,
869 segments: Vec<MessageSegment>,
870 loaded_context: LoadedContext,
871 cx: &mut Context<Self>,
872 ) -> MessageId {
873 let id = self.next_message_id.post_inc();
874 self.messages.push(Message {
875 id,
876 role,
877 segments,
878 loaded_context,
879 });
880 self.touch_updated_at();
881 cx.emit(ThreadEvent::MessageAdded(id));
882 id
883 }
884
885 pub fn edit_message(
886 &mut self,
887 id: MessageId,
888 new_role: Role,
889 new_segments: Vec<MessageSegment>,
890 loaded_context: Option<LoadedContext>,
891 cx: &mut Context<Self>,
892 ) -> bool {
893 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
894 return false;
895 };
896 message.role = new_role;
897 message.segments = new_segments;
898 if let Some(context) = loaded_context {
899 message.loaded_context = context;
900 }
901 self.touch_updated_at();
902 cx.emit(ThreadEvent::MessageEdited(id));
903 true
904 }
905
906 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
907 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
908 return false;
909 };
910 self.messages.remove(index);
911 self.touch_updated_at();
912 cx.emit(ThreadEvent::MessageDeleted(id));
913 true
914 }
915
916 /// Returns the representation of this [`Thread`] in a textual form.
917 ///
918 /// This is the representation we use when attaching a thread as context to another thread.
919 pub fn text(&self) -> String {
920 let mut text = String::new();
921
922 for message in &self.messages {
923 text.push_str(match message.role {
924 language_model::Role::User => "User:",
925 language_model::Role::Assistant => "Assistant:",
926 language_model::Role::System => "System:",
927 });
928 text.push('\n');
929
930 for segment in &message.segments {
931 match segment {
932 MessageSegment::Text(content) => text.push_str(content),
933 MessageSegment::Thinking { text: content, .. } => {
934 text.push_str(&format!("<think>{}</think>", content))
935 }
936 MessageSegment::RedactedThinking(_) => {}
937 }
938 }
939 text.push('\n');
940 }
941
942 text
943 }
944
945 /// Serializes this thread into a format for storage or telemetry.
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Await the snapshot captured at thread creation so it can be
            // embedded in the serialized form.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses and results are stored per message.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
            })
        })
    }
1015
    /// Number of agentic turns the thread may still take before it stops
    /// sending requests.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1023
1024 pub fn send_to_model(
1025 &mut self,
1026 model: Arc<dyn LanguageModel>,
1027 window: Option<AnyWindowHandle>,
1028 cx: &mut Context<Self>,
1029 ) {
1030 if self.remaining_turns == 0 {
1031 return;
1032 }
1033
1034 self.remaining_turns -= 1;
1035
1036 let request = self.to_completion_request(model.clone(), cx);
1037
1038 self.stream_completion(request, model, window, cx);
1039 }
1040
1041 pub fn used_tools_since_last_user_message(&self) -> bool {
1042 for message in self.messages.iter().rev() {
1043 if self.tool_use.message_has_tool_results(message.id) {
1044 return true;
1045 } else if message.role == Role::User {
1046 return false;
1047 }
1048 }
1049
1050 false
1051 }
1052
    /// Builds the full [`LanguageModelRequest`] for this thread: system
    /// prompt, message history with attached context, tool uses/results, and
    /// the available tools.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt is generated from the shared project context; if
        // that context isn't ready yet, surface an error but keep building the
        // request without a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            // Empty segments are dropped rather than sent as empty content.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow the assistant message that requested them,
            // as a separate message.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Max mode is only requested when the model supports it.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode)
        } else {
            Some(CompletionMode::Normal)
        };

        request
    }
1170
1171 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1172 let mut request = LanguageModelRequest {
1173 thread_id: None,
1174 prompt_id: None,
1175 mode: None,
1176 messages: vec![],
1177 tools: Vec::new(),
1178 stop: Vec::new(),
1179 temperature: None,
1180 };
1181
1182 for message in &self.messages {
1183 let mut request_message = LanguageModelRequestMessage {
1184 role: message.role,
1185 content: Vec::new(),
1186 cache: false,
1187 };
1188
1189 for segment in &message.segments {
1190 match segment {
1191 MessageSegment::Text(text) => request_message
1192 .content
1193 .push(MessageContent::Text(text.clone())),
1194 MessageSegment::Thinking { .. } => {}
1195 MessageSegment::RedactedThinking(_) => {}
1196 }
1197 }
1198
1199 if request_message.content.is_empty() {
1200 continue;
1201 }
1202
1203 request.messages.push(request_message);
1204 }
1205
1206 request.messages.push(LanguageModelRequestMessage {
1207 role: Role::User,
1208 content: vec![MessageContent::Text(added_user_message)],
1209 cache: false,
1210 });
1211
1212 request
1213 }
1214
1215 fn attached_tracked_files_state(
1216 &self,
1217 messages: &mut Vec<LanguageModelRequestMessage>,
1218 cx: &App,
1219 ) {
1220 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1221
1222 let mut stale_message = String::new();
1223
1224 let action_log = self.action_log.read(cx);
1225
1226 for stale_file in action_log.stale_buffers(cx) {
1227 let Some(file) = stale_file.read(cx).file() else {
1228 continue;
1229 };
1230
1231 if stale_message.is_empty() {
1232 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1233 }
1234
1235 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1236 }
1237
1238 let mut content = Vec::with_capacity(2);
1239
1240 if !stale_message.is_empty() {
1241 content.push(stale_message.into());
1242 }
1243
1244 if !content.is_empty() {
1245 let context_message = LanguageModelRequestMessage {
1246 role: Role::User,
1247 content,
1248 cache: false,
1249 };
1250
1251 messages.push(context_message);
1252 }
1253 }
1254
1255 pub fn stream_completion(
1256 &mut self,
1257 request: LanguageModelRequest,
1258 model: Arc<dyn LanguageModel>,
1259 window: Option<AnyWindowHandle>,
1260 cx: &mut Context<Self>,
1261 ) {
1262 let pending_completion_id = post_inc(&mut self.completion_count);
1263 let mut request_callback_parameters = if self.request_callback.is_some() {
1264 Some((request.clone(), Vec::new()))
1265 } else {
1266 None
1267 };
1268 let prompt_id = self.last_prompt_id.clone();
1269 let tool_use_metadata = ToolUseMetadata {
1270 model: model.clone(),
1271 thread_id: self.id.clone(),
1272 prompt_id: prompt_id.clone(),
1273 };
1274
1275 let task = cx.spawn(async move |thread, cx| {
1276 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1277 let initial_token_usage =
1278 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1279 let stream_completion = async {
1280 let (mut events, usage) = stream_completion_future.await?;
1281
1282 let mut stop_reason = StopReason::EndTurn;
1283 let mut current_token_usage = TokenUsage::default();
1284
1285 if let Some(usage) = usage {
1286 thread
1287 .update(cx, |_thread, cx| {
1288 cx.emit(ThreadEvent::UsageUpdated(usage));
1289 })
1290 .ok();
1291 }
1292
1293 let mut request_assistant_message_id = None;
1294
1295 while let Some(event) = events.next().await {
1296 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1297 response_events
1298 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1299 }
1300
1301 thread.update(cx, |thread, cx| {
1302 let event = match event {
1303 Ok(event) => event,
1304 Err(LanguageModelCompletionError::BadInputJson {
1305 id,
1306 tool_name,
1307 raw_input: invalid_input_json,
1308 json_parse_error,
1309 }) => {
1310 thread.receive_invalid_tool_json(
1311 id,
1312 tool_name,
1313 invalid_input_json,
1314 json_parse_error,
1315 window,
1316 cx,
1317 );
1318 return Ok(());
1319 }
1320 Err(LanguageModelCompletionError::Other(error)) => {
1321 return Err(error);
1322 }
1323 };
1324
1325 match event {
1326 LanguageModelCompletionEvent::StartMessage { .. } => {
1327 request_assistant_message_id =
1328 Some(thread.insert_assistant_message(
1329 vec![MessageSegment::Text(String::new())],
1330 cx,
1331 ));
1332 }
1333 LanguageModelCompletionEvent::Stop(reason) => {
1334 stop_reason = reason;
1335 }
1336 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1337 thread.update_token_usage_at_last_message(token_usage);
1338 thread.cumulative_token_usage = thread.cumulative_token_usage
1339 + token_usage
1340 - current_token_usage;
1341 current_token_usage = token_usage;
1342 }
1343 LanguageModelCompletionEvent::Text(chunk) => {
1344 cx.emit(ThreadEvent::ReceivedTextChunk);
1345 if let Some(last_message) = thread.messages.last_mut() {
1346 if last_message.role == Role::Assistant
1347 && !thread.tool_use.has_tool_results(last_message.id)
1348 {
1349 last_message.push_text(&chunk);
1350 cx.emit(ThreadEvent::StreamedAssistantText(
1351 last_message.id,
1352 chunk,
1353 ));
1354 } else {
1355 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1356 // of a new Assistant response.
1357 //
1358 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1359 // will result in duplicating the text of the chunk in the rendered Markdown.
1360 request_assistant_message_id =
1361 Some(thread.insert_assistant_message(
1362 vec![MessageSegment::Text(chunk.to_string())],
1363 cx,
1364 ));
1365 };
1366 }
1367 }
1368 LanguageModelCompletionEvent::Thinking {
1369 text: chunk,
1370 signature,
1371 } => {
1372 if let Some(last_message) = thread.messages.last_mut() {
1373 if last_message.role == Role::Assistant
1374 && !thread.tool_use.has_tool_results(last_message.id)
1375 {
1376 last_message.push_thinking(&chunk, signature);
1377 cx.emit(ThreadEvent::StreamedAssistantThinking(
1378 last_message.id,
1379 chunk,
1380 ));
1381 } else {
1382 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1383 // of a new Assistant response.
1384 //
1385 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1386 // will result in duplicating the text of the chunk in the rendered Markdown.
1387 request_assistant_message_id =
1388 Some(thread.insert_assistant_message(
1389 vec![MessageSegment::Thinking {
1390 text: chunk.to_string(),
1391 signature,
1392 }],
1393 cx,
1394 ));
1395 };
1396 }
1397 }
1398 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1399 let last_assistant_message_id = request_assistant_message_id
1400 .unwrap_or_else(|| {
1401 let new_assistant_message_id =
1402 thread.insert_assistant_message(vec![], cx);
1403 request_assistant_message_id =
1404 Some(new_assistant_message_id);
1405 new_assistant_message_id
1406 });
1407
1408 let tool_use_id = tool_use.id.clone();
1409 let streamed_input = if tool_use.is_input_complete {
1410 None
1411 } else {
1412 Some((&tool_use.input).clone())
1413 };
1414
1415 let ui_text = thread.tool_use.request_tool_use(
1416 last_assistant_message_id,
1417 tool_use,
1418 tool_use_metadata.clone(),
1419 cx,
1420 );
1421
1422 if let Some(input) = streamed_input {
1423 cx.emit(ThreadEvent::StreamedToolUse {
1424 tool_use_id,
1425 ui_text,
1426 input,
1427 });
1428 }
1429 }
1430 }
1431
1432 thread.touch_updated_at();
1433 cx.emit(ThreadEvent::StreamedCompletion);
1434 cx.notify();
1435
1436 thread.auto_capture_telemetry(cx);
1437 Ok(())
1438 })??;
1439
1440 smol::future::yield_now().await;
1441 }
1442
1443 thread.update(cx, |thread, cx| {
1444 thread
1445 .pending_completions
1446 .retain(|completion| completion.id != pending_completion_id);
1447
1448 // If there is a response without tool use, summarize the message. Otherwise,
1449 // allow two tool uses before summarizing.
1450 if thread.summary.is_none()
1451 && thread.messages.len() >= 2
1452 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1453 {
1454 thread.summarize(cx);
1455 }
1456 })?;
1457
1458 anyhow::Ok(stop_reason)
1459 };
1460
1461 let result = stream_completion.await;
1462
1463 thread
1464 .update(cx, |thread, cx| {
1465 thread.finalize_pending_checkpoint(cx);
1466 match result.as_ref() {
1467 Ok(stop_reason) => match stop_reason {
1468 StopReason::ToolUse => {
1469 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1470 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1471 }
1472 StopReason::EndTurn => {}
1473 StopReason::MaxTokens => {}
1474 },
1475 Err(error) => {
1476 if error.is::<PaymentRequiredError>() {
1477 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1478 } else if error.is::<MaxMonthlySpendReachedError>() {
1479 cx.emit(ThreadEvent::ShowError(
1480 ThreadError::MaxMonthlySpendReached,
1481 ));
1482 } else if let Some(error) =
1483 error.downcast_ref::<ModelRequestLimitReachedError>()
1484 {
1485 cx.emit(ThreadEvent::ShowError(
1486 ThreadError::ModelRequestLimitReached { plan: error.plan },
1487 ));
1488 } else if let Some(known_error) =
1489 error.downcast_ref::<LanguageModelKnownError>()
1490 {
1491 match known_error {
1492 LanguageModelKnownError::ContextWindowLimitExceeded {
1493 tokens,
1494 } => {
1495 thread.exceeded_window_error = Some(ExceededWindowError {
1496 model_id: model.id(),
1497 token_count: *tokens,
1498 });
1499 cx.notify();
1500 }
1501 }
1502 } else {
1503 let error_message = error
1504 .chain()
1505 .map(|err| err.to_string())
1506 .collect::<Vec<_>>()
1507 .join("\n");
1508 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1509 header: "Error interacting with language model".into(),
1510 message: SharedString::from(error_message.clone()),
1511 }));
1512 }
1513
1514 thread.cancel_last_completion(window, cx);
1515 }
1516 }
1517 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1518
1519 if let Some((request_callback, (request, response_events))) = thread
1520 .request_callback
1521 .as_mut()
1522 .zip(request_callback_parameters.as_ref())
1523 {
1524 request_callback(request, response_events);
1525 }
1526
1527 thread.auto_capture_telemetry(cx);
1528
1529 if let Ok(initial_usage) = initial_token_usage {
1530 let usage = thread.cumulative_token_usage - initial_usage;
1531
1532 telemetry::event!(
1533 "Assistant Thread Completion",
1534 thread_id = thread.id().to_string(),
1535 prompt_id = prompt_id,
1536 model = model.telemetry_id(),
1537 model_provider = model.provider_id().to_string(),
1538 input_tokens = usage.input_tokens,
1539 output_tokens = usage.output_tokens,
1540 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1541 cache_read_input_tokens = usage.cache_read_input_tokens,
1542 );
1543 }
1544 })
1545 .ok();
1546 });
1547
1548 self.pending_completions.push(PendingCompletion {
1549 id: pending_completion_id,
1550 _task: task,
1551 });
1552 }
1553
1554 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1555 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1556 return;
1557 };
1558
1559 if !model.provider.is_authenticated(cx) {
1560 return;
1561 }
1562
1563 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1564 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1565 If the conversation is about a specific subject, include it in the title. \
1566 Be descriptive. DO NOT speak in the first person.";
1567
1568 let request = self.to_summarize_request(added_user_message.into());
1569
1570 self.pending_summary = cx.spawn(async move |this, cx| {
1571 async move {
1572 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1573 let (mut messages, usage) = stream.await?;
1574
1575 if let Some(usage) = usage {
1576 this.update(cx, |_thread, cx| {
1577 cx.emit(ThreadEvent::UsageUpdated(usage));
1578 })
1579 .ok();
1580 }
1581
1582 let mut new_summary = String::new();
1583 while let Some(message) = messages.stream.next().await {
1584 let text = message?;
1585 let mut lines = text.lines();
1586 new_summary.extend(lines.next());
1587
1588 // Stop if the LLM generated multiple lines.
1589 if lines.next().is_some() {
1590 break;
1591 }
1592 }
1593
1594 this.update(cx, |this, cx| {
1595 if !new_summary.is_empty() {
1596 this.summary = Some(new_summary.into());
1597 }
1598
1599 cx.emit(ThreadEvent::SummaryGenerated);
1600 })?;
1601
1602 anyhow::Ok(())
1603 }
1604 .log_err()
1605 .await
1606 });
1607 }
1608
    /// Starts a background task that generates a detailed Markdown summary of
    /// the conversation, unless one is already generated (or being generated)
    /// for the current last message.
    ///
    /// State transitions are published on `detailed_summary_tx`; once the
    /// summary completes, the thread is saved through `thread_store` so the
    /// summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // An empty thread has nothing to summarize.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        // Publish the Generating state before spawning, so observers see the
        // transition immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails outright, reset to NotGenerated so a later
            // call can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Chunk errors are logged and skipped rather than aborting the
            // whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1698
1699 pub async fn wait_for_detailed_summary_or_text(
1700 this: &Entity<Self>,
1701 cx: &mut AsyncApp,
1702 ) -> Option<SharedString> {
1703 let mut detailed_summary_rx = this
1704 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1705 .ok()?;
1706 loop {
1707 match detailed_summary_rx.recv().await? {
1708 DetailedSummaryState::Generating { .. } => {}
1709 DetailedSummaryState::NotGenerated => {
1710 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1711 }
1712 DetailedSummaryState::Generated { text, .. } => return Some(text),
1713 }
1714 }
1715 }
1716
1717 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1718 self.detailed_summary_rx
1719 .borrow()
1720 .text()
1721 .unwrap_or_else(|| self.text().into())
1722 }
1723
1724 pub fn is_generating_detailed_summary(&self) -> bool {
1725 matches!(
1726 &*self.detailed_summary_rx.borrow(),
1727 DetailedSummaryState::Generating { .. }
1728 )
1729 }
1730
1731 pub fn use_pending_tools(
1732 &mut self,
1733 window: Option<AnyWindowHandle>,
1734 cx: &mut Context<Self>,
1735 model: Arc<dyn LanguageModel>,
1736 ) -> Vec<PendingToolUse> {
1737 self.auto_capture_telemetry(cx);
1738 let request = self.to_completion_request(model, cx);
1739 let messages = Arc::new(request.messages);
1740 let pending_tool_uses = self
1741 .tool_use
1742 .pending_tool_uses()
1743 .into_iter()
1744 .filter(|tool_use| tool_use.status.is_idle())
1745 .cloned()
1746 .collect::<Vec<_>>();
1747
1748 for tool_use in pending_tool_uses.iter() {
1749 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1750 if tool.needs_confirmation(&tool_use.input, cx)
1751 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1752 {
1753 self.tool_use.confirm_tool_use(
1754 tool_use.id.clone(),
1755 tool_use.ui_text.clone(),
1756 tool_use.input.clone(),
1757 messages.clone(),
1758 tool,
1759 );
1760 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1761 } else {
1762 self.run_tool(
1763 tool_use.id.clone(),
1764 tool_use.ui_text.clone(),
1765 tool_use.input.clone(),
1766 &messages,
1767 tool,
1768 window,
1769 cx,
1770 );
1771 }
1772 }
1773 }
1774
1775 pending_tool_uses
1776 }
1777
1778 pub fn receive_invalid_tool_json(
1779 &mut self,
1780 tool_use_id: LanguageModelToolUseId,
1781 tool_name: Arc<str>,
1782 invalid_json: Arc<str>,
1783 error: String,
1784 window: Option<AnyWindowHandle>,
1785 cx: &mut Context<Thread>,
1786 ) {
1787 log::error!("The model returned invalid input JSON: {invalid_json}");
1788
1789 let pending_tool_use = self.tool_use.insert_tool_output(
1790 tool_use_id.clone(),
1791 tool_name,
1792 Err(anyhow!("Error parsing input JSON: {error}")),
1793 self.configured_model.as_ref(),
1794 );
1795 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1796 pending_tool_use.ui_text.clone()
1797 } else {
1798 log::error!(
1799 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1800 );
1801 format!("Unknown tool {}", tool_use_id).into()
1802 };
1803
1804 cx.emit(ThreadEvent::InvalidToolInput {
1805 tool_use_id: tool_use_id.clone(),
1806 ui_text,
1807 invalid_input_json: invalid_json,
1808 });
1809
1810 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1811 }
1812
1813 pub fn run_tool(
1814 &mut self,
1815 tool_use_id: LanguageModelToolUseId,
1816 ui_text: impl Into<SharedString>,
1817 input: serde_json::Value,
1818 messages: &[LanguageModelRequestMessage],
1819 tool: Arc<dyn Tool>,
1820 window: Option<AnyWindowHandle>,
1821 cx: &mut Context<Thread>,
1822 ) {
1823 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1824 self.tool_use
1825 .run_pending_tool(tool_use_id, ui_text.into(), task);
1826 }
1827
1828 fn spawn_tool_use(
1829 &mut self,
1830 tool_use_id: LanguageModelToolUseId,
1831 messages: &[LanguageModelRequestMessage],
1832 input: serde_json::Value,
1833 tool: Arc<dyn Tool>,
1834 window: Option<AnyWindowHandle>,
1835 cx: &mut Context<Thread>,
1836 ) -> Task<()> {
1837 let tool_name: Arc<str> = tool.name().into();
1838
1839 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1840 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1841 } else {
1842 tool.run(
1843 input,
1844 messages,
1845 self.project.clone(),
1846 self.action_log.clone(),
1847 window,
1848 cx,
1849 )
1850 };
1851
1852 // Store the card separately if it exists
1853 if let Some(card) = tool_result.card.clone() {
1854 self.tool_use
1855 .insert_tool_result_card(tool_use_id.clone(), card);
1856 }
1857
1858 cx.spawn({
1859 async move |thread: WeakEntity<Thread>, cx| {
1860 let output = tool_result.output.await;
1861
1862 thread
1863 .update(cx, |thread, cx| {
1864 let pending_tool_use = thread.tool_use.insert_tool_output(
1865 tool_use_id.clone(),
1866 tool_name,
1867 output,
1868 thread.configured_model.as_ref(),
1869 );
1870 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1871 })
1872 .ok();
1873 }
1874 })
1875 }
1876
1877 fn tool_finished(
1878 &mut self,
1879 tool_use_id: LanguageModelToolUseId,
1880 pending_tool_use: Option<PendingToolUse>,
1881 canceled: bool,
1882 window: Option<AnyWindowHandle>,
1883 cx: &mut Context<Self>,
1884 ) {
1885 if self.all_tools_finished() {
1886 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
1887 if !canceled {
1888 self.send_to_model(model.clone(), window, cx);
1889 }
1890 self.auto_capture_telemetry(cx);
1891 }
1892 }
1893
1894 cx.emit(ThreadEvent::ToolFinished {
1895 tool_use_id,
1896 pending_tool_use,
1897 });
1898 }
1899
1900 /// Cancels the last pending completion, if there are any pending.
1901 ///
1902 /// Returns whether a completion was canceled.
1903 pub fn cancel_last_completion(
1904 &mut self,
1905 window: Option<AnyWindowHandle>,
1906 cx: &mut Context<Self>,
1907 ) -> bool {
1908 let mut canceled = self.pending_completions.pop().is_some();
1909
1910 for pending_tool_use in self.tool_use.cancel_pending() {
1911 canceled = true;
1912 self.tool_finished(
1913 pending_tool_use.id.clone(),
1914 Some(pending_tool_use),
1915 true,
1916 window,
1917 cx,
1918 );
1919 }
1920
1921 self.finalize_pending_checkpoint(cx);
1922 canceled
1923 }
1924
1925 /// Signals that any in-progress editing should be canceled.
1926 ///
1927 /// This method is used to notify listeners (like ActiveThread) that
1928 /// they should cancel any editing operations.
1929 pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
1930 cx.emit(ThreadEvent::CancelEditing);
1931 }
1932
1933 pub fn feedback(&self) -> Option<ThreadFeedback> {
1934 self.feedback
1935 }
1936
1937 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1938 self.message_feedback.get(&message_id).copied()
1939 }
1940
1941 pub fn report_message_feedback(
1942 &mut self,
1943 message_id: MessageId,
1944 feedback: ThreadFeedback,
1945 cx: &mut Context<Self>,
1946 ) -> Task<Result<()>> {
1947 if self.message_feedback.get(&message_id) == Some(&feedback) {
1948 return Task::ready(Ok(()));
1949 }
1950
1951 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1952 let serialized_thread = self.serialize(cx);
1953 let thread_id = self.id().clone();
1954 let client = self.project.read(cx).client();
1955
1956 let enabled_tool_names: Vec<String> = self
1957 .tools()
1958 .read(cx)
1959 .enabled_tools(cx)
1960 .iter()
1961 .map(|tool| tool.name().to_string())
1962 .collect();
1963
1964 self.message_feedback.insert(message_id, feedback);
1965
1966 cx.notify();
1967
1968 let message_content = self
1969 .message(message_id)
1970 .map(|msg| msg.to_string())
1971 .unwrap_or_default();
1972
1973 cx.background_spawn(async move {
1974 let final_project_snapshot = final_project_snapshot.await;
1975 let serialized_thread = serialized_thread.await?;
1976 let thread_data =
1977 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1978
1979 let rating = match feedback {
1980 ThreadFeedback::Positive => "positive",
1981 ThreadFeedback::Negative => "negative",
1982 };
1983 telemetry::event!(
1984 "Assistant Thread Rated",
1985 rating,
1986 thread_id,
1987 enabled_tool_names,
1988 message_id = message_id.0,
1989 message_content,
1990 thread_data,
1991 final_project_snapshot
1992 );
1993 client.telemetry().flush_events().await;
1994
1995 Ok(())
1996 })
1997 }
1998
1999 pub fn report_feedback(
2000 &mut self,
2001 feedback: ThreadFeedback,
2002 cx: &mut Context<Self>,
2003 ) -> Task<Result<()>> {
2004 let last_assistant_message_id = self
2005 .messages
2006 .iter()
2007 .rev()
2008 .find(|msg| msg.role == Role::Assistant)
2009 .map(|msg| msg.id);
2010
2011 if let Some(message_id) = last_assistant_message_id {
2012 self.report_message_feedback(message_id, feedback, cx)
2013 } else {
2014 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2015 let serialized_thread = self.serialize(cx);
2016 let thread_id = self.id().clone();
2017 let client = self.project.read(cx).client();
2018 self.feedback = Some(feedback);
2019 cx.notify();
2020
2021 cx.background_spawn(async move {
2022 let final_project_snapshot = final_project_snapshot.await;
2023 let serialized_thread = serialized_thread.await?;
2024 let thread_data = serde_json::to_value(serialized_thread)
2025 .unwrap_or_else(|_| serde_json::Value::Null);
2026
2027 let rating = match feedback {
2028 ThreadFeedback::Positive => "positive",
2029 ThreadFeedback::Negative => "negative",
2030 };
2031 telemetry::event!(
2032 "Assistant Thread Rated",
2033 rating,
2034 thread_id,
2035 thread_data,
2036 final_project_snapshot
2037 );
2038 client.telemetry().flush_events().await;
2039
2040 Ok(())
2041 })
2042 }
2043 }
2044
2045 /// Create a snapshot of the current project state including git information and unsaved buffers.
2046 fn project_snapshot(
2047 project: Entity<Project>,
2048 cx: &mut Context<Self>,
2049 ) -> Task<Arc<ProjectSnapshot>> {
2050 let git_store = project.read(cx).git_store().clone();
2051 let worktree_snapshots: Vec<_> = project
2052 .read(cx)
2053 .visible_worktrees(cx)
2054 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2055 .collect();
2056
2057 cx.spawn(async move |_, cx| {
2058 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2059
2060 let mut unsaved_buffers = Vec::new();
2061 cx.update(|app_cx| {
2062 let buffer_store = project.read(app_cx).buffer_store();
2063 for buffer_handle in buffer_store.read(app_cx).buffers() {
2064 let buffer = buffer_handle.read(app_cx);
2065 if buffer.is_dirty() {
2066 if let Some(file) = buffer.file() {
2067 let path = file.path().to_string_lossy().to_string();
2068 unsaved_buffers.push(path);
2069 }
2070 }
2071 }
2072 })
2073 .ok();
2074
2075 Arc::new(ProjectSnapshot {
2076 worktree_snapshots,
2077 unsaved_buffer_paths: unsaved_buffers,
2078 timestamp: Utc::now(),
2079 })
2080 })
2081 }
2082
    /// Captures a point-in-time snapshot of `worktree`: its absolute path and,
    /// when a repository in `git_store` covers it, git state (remote URL, HEAD
    /// SHA, current branch, and a HEAD-to-worktree diff).
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot rather
            // than failing the whole project snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose working copy contains this worktree.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Non-local repositories expose no backend, so
                            // only the branch name is available.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result layers produced by the entity
            // update and the repository job; any failure yields `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2160
2161 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2162 let mut markdown = Vec::new();
2163
2164 if let Some(summary) = self.summary() {
2165 writeln!(markdown, "# {summary}\n")?;
2166 };
2167
2168 for message in self.messages() {
2169 writeln!(
2170 markdown,
2171 "## {role}\n",
2172 role = match message.role {
2173 Role::User => "User",
2174 Role::Assistant => "Assistant",
2175 Role::System => "System",
2176 }
2177 )?;
2178
2179 if !message.loaded_context.text.is_empty() {
2180 writeln!(markdown, "{}", message.loaded_context.text)?;
2181 }
2182
2183 if !message.loaded_context.images.is_empty() {
2184 writeln!(
2185 markdown,
2186 "\n{} images attached as context.\n",
2187 message.loaded_context.images.len()
2188 )?;
2189 }
2190
2191 for segment in &message.segments {
2192 match segment {
2193 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2194 MessageSegment::Thinking { text, .. } => {
2195 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2196 }
2197 MessageSegment::RedactedThinking(_) => {}
2198 }
2199 }
2200
2201 for tool_use in self.tool_uses_for_message(message.id, cx) {
2202 writeln!(
2203 markdown,
2204 "**Use Tool: {} ({})**",
2205 tool_use.name, tool_use.id
2206 )?;
2207 writeln!(markdown, "```json")?;
2208 writeln!(
2209 markdown,
2210 "{}",
2211 serde_json::to_string_pretty(&tool_use.input)?
2212 )?;
2213 writeln!(markdown, "```")?;
2214 }
2215
2216 for tool_result in self.tool_results_for_message(message.id) {
2217 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2218 if tool_result.is_error {
2219 write!(markdown, " (Error)")?;
2220 }
2221
2222 writeln!(markdown, "**\n")?;
2223 writeln!(markdown, "{}", tool_result.content)?;
2224 }
2225 }
2226
2227 Ok(String::from_utf8_lossy(&markdown).to_string())
2228 }
2229
2230 pub fn keep_edits_in_range(
2231 &mut self,
2232 buffer: Entity<language::Buffer>,
2233 buffer_range: Range<language::Anchor>,
2234 cx: &mut Context<Self>,
2235 ) {
2236 self.action_log.update(cx, |action_log, cx| {
2237 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2238 });
2239 }
2240
2241 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2242 self.action_log
2243 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2244 }
2245
2246 pub fn reject_edits_in_ranges(
2247 &mut self,
2248 buffer: Entity<language::Buffer>,
2249 buffer_ranges: Vec<Range<language::Anchor>>,
2250 cx: &mut Context<Self>,
2251 ) -> Task<Result<()>> {
2252 self.action_log.update(cx, |action_log, cx| {
2253 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2254 })
2255 }
2256
2257 pub fn action_log(&self) -> &Entity<ActionLog> {
2258 &self.action_log
2259 }
2260
2261 pub fn project(&self) -> &Entity<Project> {
2262 &self.project
2263 }
2264
2265 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2266 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2267 return;
2268 }
2269
2270 let now = Instant::now();
2271 if let Some(last) = self.last_auto_capture_at {
2272 if now.duration_since(last).as_secs() < 10 {
2273 return;
2274 }
2275 }
2276
2277 self.last_auto_capture_at = Some(now);
2278
2279 let thread_id = self.id().clone();
2280 let github_login = self
2281 .project
2282 .read(cx)
2283 .user_store()
2284 .read(cx)
2285 .current_user()
2286 .map(|user| user.github_login.clone());
2287 let client = self.project.read(cx).client().clone();
2288 let serialize_task = self.serialize(cx);
2289
2290 cx.background_executor()
2291 .spawn(async move {
2292 if let Ok(serialized_thread) = serialize_task.await {
2293 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2294 telemetry::event!(
2295 "Agent Thread Auto-Captured",
2296 thread_id = thread_id.to_string(),
2297 thread_data = thread_data,
2298 auto_capture_reason = "tracked_user",
2299 github_login = github_login
2300 );
2301
2302 client.telemetry().flush_events().await;
2303 }
2304 }
2305 })
2306 .detach();
2307 }
2308
    /// Total token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2312
2313 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2314 let Some(model) = self.configured_model.as_ref() else {
2315 return TotalTokenUsage::default();
2316 };
2317
2318 let max = model.model.max_token_count();
2319
2320 let index = self
2321 .messages
2322 .iter()
2323 .position(|msg| msg.id == message_id)
2324 .unwrap_or(0);
2325
2326 if index == 0 {
2327 return TotalTokenUsage { total: 0, max };
2328 }
2329
2330 let token_usage = &self
2331 .request_token_usage
2332 .get(index - 1)
2333 .cloned()
2334 .unwrap_or_default();
2335
2336 TotalTokenUsage {
2337 total: token_usage.total_tokens() as usize,
2338 max,
2339 }
2340 }
2341
2342 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2343 let model = self.configured_model.as_ref()?;
2344
2345 let max = model.model.max_token_count();
2346
2347 if let Some(exceeded_error) = &self.exceeded_window_error {
2348 if model.model.id() == exceeded_error.model_id {
2349 return Some(TotalTokenUsage {
2350 total: exceeded_error.token_count,
2351 max,
2352 });
2353 }
2354 }
2355
2356 let total = self
2357 .token_usage_at_last_message()
2358 .unwrap_or_default()
2359 .total_tokens() as usize;
2360
2361 Some(TotalTokenUsage { total, max })
2362 }
2363
2364 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2365 self.request_token_usage
2366 .get(self.messages.len().saturating_sub(1))
2367 .or_else(|| self.request_token_usage.last())
2368 .cloned()
2369 }
2370
2371 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2372 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2373 self.request_token_usage
2374 .resize(self.messages.len(), placeholder);
2375
2376 if let Some(last) = self.request_token_usage.last_mut() {
2377 *last = token_usage;
2378 }
2379 }
2380
2381 pub fn deny_tool_use(
2382 &mut self,
2383 tool_use_id: LanguageModelToolUseId,
2384 tool_name: Arc<str>,
2385 window: Option<AnyWindowHandle>,
2386 cx: &mut Context<Self>,
2387 ) {
2388 let err = Err(anyhow::anyhow!(
2389 "Permission to run tool action denied by user"
2390 ));
2391
2392 self.tool_use.insert_tool_output(
2393 tool_use_id.clone(),
2394 tool_name,
2395 err,
2396 self.configured_model.as_ref(),
2397 );
2398 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2399 }
2400}
2401
/// User-facing failures that can occur while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further requests are served.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model request quota for the user's plan has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A general error carrying a display header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2416
/// Notifications emitted by a thread over the course of a conversation:
/// streamed model output, message lifecycle changes, and tool-use progress.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// Updated request-usage information was received.
    UsageUpdated(RequestUsage),
    /// A completion event was streamed from the model.
    StreamedCompletion,
    /// A chunk of text was received from the model.
    ReceivedTextChunk,
    /// Streamed assistant text for the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with parsed JSON input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that was not valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// A message's content was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Message editing was cancelled.
    CancelEditing,
}
2454
// Marker impl allowing `Thread` to emit `ThreadEvent`s to GPUI subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2456
/// A completion request that is still in flight.
struct PendingCompletion {
    // Identifies this completion so a finished task can be matched against
    // the most recently started one.
    id: usize,
    // Keeps the spawned completion task alive; presumably dropping it cancels
    // the work (GPUI `Task` semantics) — NOTE(review): confirm if relied upon.
    _task: Task<()>,
}
2461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use assistant_tool::ToolRegistry;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // Verifies that a file attached as context is rendered into the message's
    // loaded context and prepended to the message text in the request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Verifies that each message only carries context that wasn't already
    // attached to an earlier message, and that removing contexts reflows what
    // counts as "new" relative to a given message.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Relative to message 2, contexts attached at or after it count as new.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }

    // Verifies that messages without attached context produce empty loaded
    // context and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Verifies that editing an attached buffer after it was read causes the
    // completion request to end with a stale-buffer notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // Registers the settings and globals that the thread tests depend on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Builds a workspace, thread store, empty thread, context store, and a
    // fake language model for use in the tests above.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }

    // Opens `path` in the project and registers its buffer with the context
    // store, returning the buffer so tests can edit it.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}