1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
18 WeakEntity,
19};
20use language_model::{
21 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
22 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
23 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
24 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
25 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
26 StopReason, TokenUsage,
27};
28use postage::stream::Stream as _;
29use project::Project;
30use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
31use prompt_store::{ModelContext, PromptBuilder};
32use proto::Plan;
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use settings::Settings;
36use thiserror::Error;
37use util::{ResultExt as _, TryFutureExt as _, post_inc};
38use uuid::Uuid;
39use zed_llm_client::CompletionMode;
40
41use crate::ThreadStore;
42use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
43use crate::thread_store::{
44 SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
45 SerializedToolResult, SerializedToolUse, SharedProjectContext,
46};
47use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
48
/// Globally unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
53
54impl ThreadId {
55 pub fn new() -> Self {
56 Self(Uuid::new_v4().to_string().into())
57 }
58}
59
60impl std::fmt::Display for ThreadId {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{}", self.0)
63 }
64}
65
66impl From<&str> for ThreadId {
67 fn from(value: &str) -> Self {
68 Self(value.into())
69 }
70}
71
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>); // UUID string; `Arc` makes clones cheap.
77
78impl PromptId {
79 pub fn new() -> Self {
80 Self(Uuid::new_v4().to_string().into())
81 }
82}
83
84impl std::fmt::Display for PromptId {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(f, "{}", self.0)
87 }
88}
89
/// Sequential, per-thread identifier of a [`Message`]; IDs increase in insertion order.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
92
93impl MessageId {
94 fn post_inc(&mut self) -> Self {
95 Self(post_inc(&mut self.0))
96 }
97}
98
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Sequential ID, unique within the thread.
    pub id: MessageId,
    /// Who authored the message (user, assistant, or system).
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context (files, etc.) that was attached to this message.
    pub loaded_context: LoadedContext,
}
107
108impl Message {
109 /// Returns whether the message contains any meaningful text that should be displayed
110 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
111 pub fn should_display_content(&self) -> bool {
112 self.segments.iter().all(|segment| segment.should_display())
113 }
114
115 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
116 if let Some(MessageSegment::Thinking {
117 text: segment,
118 signature: current_signature,
119 }) = self.segments.last_mut()
120 {
121 if let Some(signature) = signature {
122 *current_signature = Some(signature);
123 }
124 segment.push_str(text);
125 } else {
126 self.segments.push(MessageSegment::Thinking {
127 text: text.to_string(),
128 signature,
129 });
130 }
131 }
132
133 pub fn push_text(&mut self, text: &str) {
134 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Text(text.to_string()));
138 }
139 }
140
141 pub fn to_string(&self) -> String {
142 let mut result = String::new();
143
144 if !self.loaded_context.text.is_empty() {
145 result.push_str(&self.loaded_context.text);
146 }
147
148 for segment in &self.segments {
149 match segment {
150 MessageSegment::Text(text) => result.push_str(text),
151 MessageSegment::Thinking { text, .. } => {
152 result.push_str("<think>\n");
153 result.push_str(text);
154 result.push_str("\n</think>");
155 }
156 MessageSegment::RedactedThinking(_) => {}
157 }
158 }
159
160 result
161 }
162}
163
/// One contiguous piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary visible text.
    Text(String),
    /// Model "thinking" output, optionally carrying a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Thinking content withheld by the provider, stored as opaque bytes.
    RedactedThinking(Vec<u8>),
}
173
174impl MessageSegment {
175 pub fn should_display(&self) -> bool {
176 match self {
177 Self::Text(text) => text.is_empty(),
178 Self::Thinking { text, .. } => text.is_empty(),
179 Self::RedactedThinking(_) => false,
180 }
181 }
182}
183
/// A snapshot of project state, captured for storage alongside a thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// Per-worktree state (path and git info) at capture time.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
190
/// Snapshot of a single worktree's location and git status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree root.
    pub worktree_path: String,
    /// Git metadata, if the worktree is a repository.
    pub git_state: Option<GitState>,
}

/// Git repository metadata captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the remote, if any.
    pub remote_url: Option<String>,
    /// SHA of the current HEAD commit, if resolvable.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    /// Uncommitted diff text, if collected.
    pub diff: Option<String>,
}
204
/// Associates a git checkpoint with the message that created it, so the
/// project can later be restored to the state at that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up/down) on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
216
/// Status of the most recent checkpoint-restore operation.
pub enum LastRestoreCheckpoint {
    /// The restore is still in progress.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed with the given error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
226
227impl LastRestoreCheckpoint {
228 pub fn message_id(&self) -> MessageId {
229 match self {
230 LastRestoreCheckpoint::Pending { message_id } => *message_id,
231 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
232 }
233 }
234}
235
/// Progress of generating a thread's detailed summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// Generation is running; `message_id` is the last message included.
    Generating {
        message_id: MessageId,
    },
    /// Generation finished; `message_id` is the last message included.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
248
249impl DetailedSummaryState {
250 fn text(&self) -> Option<SharedString> {
251 if let Self::Generated { text, .. } = self {
252 Some(text.clone())
253 } else {
254 None
255 }
256 }
257}
258
/// A thread's token usage relative to the model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window maximum; `0` when no model is selected.
    pub max: usize,
}
264
265impl TotalTokenUsage {
266 pub fn ratio(&self) -> TokenUsageRatio {
267 #[cfg(debug_assertions)]
268 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
269 .unwrap_or("0.8".to_string())
270 .parse()
271 .unwrap();
272 #[cfg(not(debug_assertions))]
273 let warning_threshold: f32 = 0.8;
274
275 // When the maximum is unknown because there is no selected model,
276 // avoid showing the token limit warning.
277 if self.max == 0 {
278 TokenUsageRatio::Normal
279 } else if self.total >= self.max {
280 TokenUsageRatio::Exceeded
281 } else if self.total as f32 / self.max as f32 >= warning_threshold {
282 TokenUsageRatio::Warning
283 } else {
284 TokenUsageRatio::Normal
285 }
286 }
287
288 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
289 TotalTokenUsage {
290 total: self.total + tokens,
291 max: self.max,
292 }
293 }
294}
295
/// Coarse classification of token usage for UI warnings.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold.
    #[default]
    Normal,
    /// Usage crossed the warning threshold (see [`TotalTokenUsage::ratio`]).
    Warning,
    /// Usage reached or exceeded the context-window maximum.
    Exceeded,
}
303
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last time the thread was modified.
    updated_at: DateTime<Utc>,
    /// Short summary shown in the UI; `None` until first generated.
    summary: Option<SharedString>,
    /// In-flight task generating the short summary.
    pending_summary: Task<Option<()>>,
    /// In-flight task generating the detailed summary.
    detailed_summary_task: Task<Option<()>>,
    /// Watch channel publishing detailed-summary progress.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    /// Completion-mode override, applied when the model supports it.
    completion_mode: Option<CompletionMode>,
    /// All messages, ordered by ascending [`MessageId`].
    messages: Vec<Message>,
    next_message_id: MessageId,
    /// Regenerated each time the user submits a new prompt.
    last_prompt_id: PromptId,
    /// Shared project context used to build the system prompt.
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message that created them.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Status of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting finalization once generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Token usage accumulated across all requests.
    cumulative_token_usage: TokenUsage,
    /// Set when a request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    /// Thread-level user feedback.
    feedback: Option<ThreadFeedback>,
    /// Per-message user feedback.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// Timestamp throttling automatic telemetry capture.
    last_auto_capture_at: Option<Instant>,
    /// Optional hook observing each request and its raw response events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Turns left before `send_to_model` becomes a no-op; `u32::MAX` means unlimited.
    remaining_turns: u32,
    /// Currently selected model, if any.
    configured_model: Option<ConfiguredModel>,
}

/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
349
350impl Thread {
    /// Creates an empty thread for `project`, selecting the registry's
    /// default model and capturing an initial project snapshot in the
    /// background.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: None,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Captured asynchronously so thread creation stays fast.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
401
    /// Reconstructs a thread from its serialized form, restoring messages,
    /// tool-use state, token usage, and the previously selected model
    /// (falling back to the registry default when that model is unavailable).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest serialized message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the serialized model selection; fall back to the default model.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; structured
                    // contexts and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
495
    /// Installs a hook that observes every completion request and the raw
    /// events received in response (used for capture/debugging).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
503
    /// The thread's unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Last time the thread was modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, if one exists.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// The shared project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
531
532 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
533 if self.configured_model.is_none() {
534 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
535 }
536 self.configured_model.clone()
537 }
538
    /// The currently selected model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the selected model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
547
    /// Summary used before one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The thread summary, or [`Self::DEFAULT_SUMMARY`] if none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
553
554 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
555 let Some(current_summary) = &self.summary else {
556 // Don't allow setting summary until generated
557 return;
558 };
559
560 let mut new_summary = new_summary.into();
561
562 if new_summary.is_empty() {
563 new_summary = Self::DEFAULT_SUMMARY;
564 }
565
566 if current_summary != &new_summary {
567 self.summary = Some(new_summary);
568 cx.emit(ThreadEvent::SummaryChanged);
569 }
570 }
571
    /// The completion-mode override for this thread, if any.
    pub fn completion_mode(&self) -> Option<CompletionMode> {
        self.completion_mode
    }

    /// Sets the completion-mode override (applied only when the model
    /// supports it; see `to_completion_request`).
    pub fn set_completion_mode(&mut self, mode: Option<CompletionMode>) {
        self.completion_mode = mode;
    }
579
580 pub fn message(&self, id: MessageId) -> Option<&Message> {
581 let index = self
582 .messages
583 .binary_search_by(|message| message.id.cmp(&id))
584 .ok()?;
585
586 self.messages.get(index)
587 }
588
    /// Iterates over all messages in ID order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
600
601 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
602 self.tool_use
603 .pending_tool_uses()
604 .into_iter()
605 .find(|tool_use| &tool_use.id == id)
606 }
607
    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
622
    /// Restores the project to `checkpoint`'s git state, truncating the
    /// thread back to the checkpoint's message on success. Progress and
    /// failure are surfaced via `last_restore_checkpoint` and
    /// [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so it can be shown next to the message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
658
    /// Once generation finishes, compares the pending checkpoint against the
    /// repository's current state and keeps it only if something changed
    /// (or if the comparison itself failed, erring on the side of keeping it).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Identical checkpoints carry no information; drop them.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
697
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Status of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
708
    /// Removes the message with `message_id` and every later message, along
    /// with their recorded checkpoints.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        // Drop checkpoints belonging to the removed messages as we drain them.
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
722
    /// Iterates over the context items attached to the given message
    /// (empty if the message doesn't exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
730
731 pub fn is_turn_end(&self, ix: usize) -> bool {
732 if self.messages.is_empty() {
733 return false;
734 }
735
736 if !self.is_generating() && ix == self.messages.len() - 1 {
737 return true;
738 }
739
740 let Some(message) = self.messages.get(ix) else {
741 return false;
742 };
743
744 if message.role != Role::Assistant {
745 return false;
746 }
747
748 self.messages
749 .get(ix + 1)
750 .and_then(|message| {
751 self.message(message.id)
752 .map(|next_message| next_message.role == Role::User)
753 })
754 .unwrap_or(false)
755 }
756
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses associated with the given assistant message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results associated with the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of the tool use with the given ID, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of the given tool use, if available.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card rendered for the given tool use, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
789
790 /// Return tools that are both enabled and supported by the model
791 pub fn available_tools(
792 &self,
793 cx: &App,
794 model: Arc<dyn LanguageModel>,
795 ) -> Vec<LanguageModelRequestTool> {
796 if model.supports_tools() {
797 self.tools()
798 .read(cx)
799 .enabled_tools(cx)
800 .into_iter()
801 .filter_map(|tool| {
802 // Skip tools that cannot be supported
803 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
804 Some(LanguageModelRequestTool {
805 name: tool.name(),
806 description: tool.description(),
807 input_schema,
808 })
809 })
810 .collect()
811 } else {
812 Vec::default()
813 }
814 }
815
    /// Appends a user message with its loaded context, tracking any
    /// referenced buffers in the action log and stashing `git_checkpoint`
    /// as the pending checkpoint for this message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            // Finalized (or discarded) once the resulting generation completes.
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
849
    /// Appends an assistant message with the given segments and no context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
    }
857
858 pub fn insert_message(
859 &mut self,
860 role: Role,
861 segments: Vec<MessageSegment>,
862 loaded_context: LoadedContext,
863 cx: &mut Context<Self>,
864 ) -> MessageId {
865 let id = self.next_message_id.post_inc();
866 self.messages.push(Message {
867 id,
868 role,
869 segments,
870 loaded_context,
871 });
872 self.touch_updated_at();
873 cx.emit(ThreadEvent::MessageAdded(id));
874 id
875 }
876
877 pub fn edit_message(
878 &mut self,
879 id: MessageId,
880 new_role: Role,
881 new_segments: Vec<MessageSegment>,
882 loaded_context: Option<LoadedContext>,
883 cx: &mut Context<Self>,
884 ) -> bool {
885 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
886 return false;
887 };
888 message.role = new_role;
889 message.segments = new_segments;
890 if let Some(context) = loaded_context {
891 message.loaded_context = context;
892 }
893 self.touch_updated_at();
894 cx.emit(ThreadEvent::MessageEdited(id));
895 true
896 }
897
898 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
899 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
900 return false;
901 };
902 self.messages.remove(index);
903 self.touch_updated_at();
904 cx.emit(ThreadEvent::MessageDeleted(id));
905 true
906 }
907
908 /// Returns the representation of this [`Thread`] in a textual form.
909 ///
910 /// This is the representation we use when attaching a thread as context to another thread.
911 pub fn text(&self) -> String {
912 let mut text = String::new();
913
914 for message in &self.messages {
915 text.push_str(match message.role {
916 language_model::Role::User => "User:",
917 language_model::Role::Assistant => "Assistant:",
918 language_model::Role::System => "System:",
919 });
920 text.push('\n');
921
922 for segment in &message.segments {
923 match segment {
924 MessageSegment::Text(content) => text.push_str(content),
925 MessageSegment::Thinking { text: content, .. } => {
926 text.push_str(&format!("<think>{}</think>", content))
927 }
928 MessageSegment::RedactedThinking(_) => {}
929 }
930 }
931 text.push('\n');
932 }
933
934 text
935 }
936
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot task so the serialized thread includes it.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
            })
        })
    }
1007
    /// Turns left before `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets the number of turns the thread may still issue.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1015
    /// Sends the current thread state to `model` as a streaming completion
    /// request, consuming one turn. No-op when no turns remain.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1032
1033 pub fn used_tools_since_last_user_message(&self) -> bool {
1034 for message in self.messages.iter().rev() {
1035 if self.tool_use.message_has_tool_results(message.id) {
1036 return true;
1037 } else if message.role == Role::User {
1038 return false;
1039 }
1040 }
1041
1042 false
1043 }
1044
    /// Builds the [`LanguageModelRequest`] for this thread: system prompt,
    /// all messages with their context and tool uses/results, cache markers,
    /// tracked-file state, and the available tools.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Building the system prompt needs the project context; if it isn't
        // ready yet (or prompt generation fails) an error is surfaced and the
        // request proceeds without a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context precedes the message's own text.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow as a separate message after the assistant
            // message that requested them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // The completion-mode override only applies when the model supports it.
        request.mode = if model.supports_max_mode() {
            self.completion_mode
        } else {
            None
        };

        request
    }
1162
1163 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1164 let mut request = LanguageModelRequest {
1165 thread_id: None,
1166 prompt_id: None,
1167 mode: None,
1168 messages: vec![],
1169 tools: Vec::new(),
1170 stop: Vec::new(),
1171 temperature: None,
1172 };
1173
1174 for message in &self.messages {
1175 let mut request_message = LanguageModelRequestMessage {
1176 role: message.role,
1177 content: Vec::new(),
1178 cache: false,
1179 };
1180
1181 for segment in &message.segments {
1182 match segment {
1183 MessageSegment::Text(text) => request_message
1184 .content
1185 .push(MessageContent::Text(text.clone())),
1186 MessageSegment::Thinking { .. } => {}
1187 MessageSegment::RedactedThinking(_) => {}
1188 }
1189 }
1190
1191 if request_message.content.is_empty() {
1192 continue;
1193 }
1194
1195 request.messages.push(request_message);
1196 }
1197
1198 request.messages.push(LanguageModelRequestMessage {
1199 role: Role::User,
1200 content: vec![MessageContent::Text(added_user_message)],
1201 cache: false,
1202 });
1203
1204 request
1205 }
1206
1207 fn attached_tracked_files_state(
1208 &self,
1209 messages: &mut Vec<LanguageModelRequestMessage>,
1210 cx: &App,
1211 ) {
1212 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1213
1214 let mut stale_message = String::new();
1215
1216 let action_log = self.action_log.read(cx);
1217
1218 for stale_file in action_log.stale_buffers(cx) {
1219 let Some(file) = stale_file.read(cx).file() else {
1220 continue;
1221 };
1222
1223 if stale_message.is_empty() {
1224 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1225 }
1226
1227 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1228 }
1229
1230 let mut content = Vec::with_capacity(2);
1231
1232 if !stale_message.is_empty() {
1233 content.push(stale_message.into());
1234 }
1235
1236 if !content.is_empty() {
1237 let context_message = LanguageModelRequestMessage {
1238 role: Role::User,
1239 content,
1240 cache: false,
1241 };
1242
1243 messages.push(context_message);
1244 }
1245 }
1246
    /// Streams a completion from `model` for `request`, applying events to
    /// this thread as they arrive.
    ///
    /// The work runs in a spawned task that is registered in
    /// `self.pending_completions` under a fresh completion id so it can be
    /// canceled via `cancel_last_completion`. The task appends streamed
    /// text/thinking chunks to the trailing assistant message (creating one
    /// when needed), records tool uses, accumulates token usage, maps errors
    /// to `ThreadEvent::ShowError` variants, and — on a `ToolUse` stop
    /// reason — starts the pending tools.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only capture the request and its response events when a request
        // callback is installed (used for inspection/debugging).
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Usage before this completion; used at the end to compute this
            // completion's delta for telemetry.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                // Id of the assistant message this stream appends to; created
                // lazily when the first relevant event arrives.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                // Malformed tool input isn't fatal: record it
                                // as a failed tool call and keep streaming.
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // `token_usage` is cumulative for this
                                // completion, so fold in only the delta since
                                // the previous update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only re-emit the input while the tool use is
                                // still streaming (input incomplete).
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other tasks can run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to specific UI events;
                            // anything else becomes a generic error message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report this completion's token delta to telemetry.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1545
    /// Generates a short (3-7 word) title for this thread using the configured
    /// thread-summary model and stores it in `self.summary`.
    ///
    /// No-op when no summary model is configured or its provider is not
    /// authenticated. Assigning `self.pending_summary` replaces (and thereby
    /// cancels) any summarization task already in flight.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                // Accumulate only the first line of output; a title should
                // never span multiple lines.
                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // Keep the previous summary if the model produced nothing.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1600
    /// Starts generating a detailed Markdown summary of the thread, unless one
    /// is already generated (or being generated) for the current last message.
    ///
    /// On completion the summary is published through `detailed_summary_tx`
    /// and the thread is saved via `thread_store` so the summary can be reused
    /// later. No-op when the thread is empty, no summary model is configured,
    /// or its provider is not authenticated.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        // Publish the Generating state immediately so observers (and repeated
        // calls) see the generation as in progress for this message.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, roll the state back so a later
            // call can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Concatenate all chunks; individual chunk errors are logged and
            // skipped rather than aborting the whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1690
1691 pub async fn wait_for_detailed_summary_or_text(
1692 this: &Entity<Self>,
1693 cx: &mut AsyncApp,
1694 ) -> Option<SharedString> {
1695 let mut detailed_summary_rx = this
1696 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1697 .ok()?;
1698 loop {
1699 match detailed_summary_rx.recv().await? {
1700 DetailedSummaryState::Generating { .. } => {}
1701 DetailedSummaryState::NotGenerated => {
1702 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1703 }
1704 DetailedSummaryState::Generated { text, .. } => return Some(text),
1705 }
1706 }
1707 }
1708
1709 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1710 self.detailed_summary_rx
1711 .borrow()
1712 .text()
1713 .unwrap_or_else(|| self.text().into())
1714 }
1715
1716 pub fn is_generating_detailed_summary(&self) -> bool {
1717 matches!(
1718 &*self.detailed_summary_rx.borrow(),
1719 DetailedSummaryState::Generating { .. }
1720 )
1721 }
1722
1723 pub fn use_pending_tools(
1724 &mut self,
1725 window: Option<AnyWindowHandle>,
1726 cx: &mut Context<Self>,
1727 model: Arc<dyn LanguageModel>,
1728 ) -> Vec<PendingToolUse> {
1729 self.auto_capture_telemetry(cx);
1730 let request = self.to_completion_request(model, cx);
1731 let messages = Arc::new(request.messages);
1732 let pending_tool_uses = self
1733 .tool_use
1734 .pending_tool_uses()
1735 .into_iter()
1736 .filter(|tool_use| tool_use.status.is_idle())
1737 .cloned()
1738 .collect::<Vec<_>>();
1739
1740 for tool_use in pending_tool_uses.iter() {
1741 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1742 if tool.needs_confirmation(&tool_use.input, cx)
1743 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1744 {
1745 self.tool_use.confirm_tool_use(
1746 tool_use.id.clone(),
1747 tool_use.ui_text.clone(),
1748 tool_use.input.clone(),
1749 messages.clone(),
1750 tool,
1751 );
1752 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1753 } else {
1754 self.run_tool(
1755 tool_use.id.clone(),
1756 tool_use.ui_text.clone(),
1757 tool_use.input.clone(),
1758 &messages,
1759 tool,
1760 window,
1761 cx,
1762 );
1763 }
1764 }
1765 }
1766
1767 pending_tool_uses
1768 }
1769
1770 pub fn receive_invalid_tool_json(
1771 &mut self,
1772 tool_use_id: LanguageModelToolUseId,
1773 tool_name: Arc<str>,
1774 invalid_json: Arc<str>,
1775 error: String,
1776 window: Option<AnyWindowHandle>,
1777 cx: &mut Context<Thread>,
1778 ) {
1779 log::error!("The model returned invalid input JSON: {invalid_json}");
1780
1781 let pending_tool_use = self.tool_use.insert_tool_output(
1782 tool_use_id.clone(),
1783 tool_name,
1784 Err(anyhow!("Error parsing input JSON: {error}")),
1785 self.configured_model.as_ref(),
1786 );
1787 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1788 pending_tool_use.ui_text.clone()
1789 } else {
1790 log::error!(
1791 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1792 );
1793 format!("Unknown tool {}", tool_use_id).into()
1794 };
1795
1796 cx.emit(ThreadEvent::InvalidToolInput {
1797 tool_use_id: tool_use_id.clone(),
1798 ui_text,
1799 invalid_input_json: invalid_json,
1800 });
1801
1802 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1803 }
1804
1805 pub fn run_tool(
1806 &mut self,
1807 tool_use_id: LanguageModelToolUseId,
1808 ui_text: impl Into<SharedString>,
1809 input: serde_json::Value,
1810 messages: &[LanguageModelRequestMessage],
1811 tool: Arc<dyn Tool>,
1812 window: Option<AnyWindowHandle>,
1813 cx: &mut Context<Thread>,
1814 ) {
1815 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1816 self.tool_use
1817 .run_pending_tool(tool_use_id, ui_text.into(), task);
1818 }
1819
    /// Runs `tool` asynchronously, returning a task that stores the tool's
    /// output on this thread and triggers completion handling when done.
    ///
    /// Disabled tools are never run; they immediately produce an error result.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // `.into()` wraps the ready error task into the same result type that
        // `tool.run` returns.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        // Record the output, then let `tool_finished` decide
                        // whether to send the results back to the model.
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
1868
1869 fn tool_finished(
1870 &mut self,
1871 tool_use_id: LanguageModelToolUseId,
1872 pending_tool_use: Option<PendingToolUse>,
1873 canceled: bool,
1874 window: Option<AnyWindowHandle>,
1875 cx: &mut Context<Self>,
1876 ) {
1877 if self.all_tools_finished() {
1878 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
1879 if !canceled {
1880 self.send_to_model(model.clone(), window, cx);
1881 }
1882 self.auto_capture_telemetry(cx);
1883 }
1884 }
1885
1886 cx.emit(ThreadEvent::ToolFinished {
1887 tool_use_id,
1888 pending_tool_use,
1889 });
1890 }
1891
1892 /// Cancels the last pending completion, if there are any pending.
1893 ///
1894 /// Returns whether a completion was canceled.
1895 pub fn cancel_last_completion(
1896 &mut self,
1897 window: Option<AnyWindowHandle>,
1898 cx: &mut Context<Self>,
1899 ) -> bool {
1900 let mut canceled = self.pending_completions.pop().is_some();
1901
1902 for pending_tool_use in self.tool_use.cancel_pending() {
1903 canceled = true;
1904 self.tool_finished(
1905 pending_tool_use.id.clone(),
1906 Some(pending_tool_use),
1907 true,
1908 window,
1909 cx,
1910 );
1911 }
1912
1913 self.finalize_pending_checkpoint(cx);
1914 canceled
1915 }
1916
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It does not cancel anything
    /// itself; listeners react to the emitted `CancelEditing` event.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
1924
    /// The whole-thread feedback, if the user has rated this thread.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1928
    /// The recorded feedback for `message_id`, if the user has rated it.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1932
1933 pub fn report_message_feedback(
1934 &mut self,
1935 message_id: MessageId,
1936 feedback: ThreadFeedback,
1937 cx: &mut Context<Self>,
1938 ) -> Task<Result<()>> {
1939 if self.message_feedback.get(&message_id) == Some(&feedback) {
1940 return Task::ready(Ok(()));
1941 }
1942
1943 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1944 let serialized_thread = self.serialize(cx);
1945 let thread_id = self.id().clone();
1946 let client = self.project.read(cx).client();
1947
1948 let enabled_tool_names: Vec<String> = self
1949 .tools()
1950 .read(cx)
1951 .enabled_tools(cx)
1952 .iter()
1953 .map(|tool| tool.name().to_string())
1954 .collect();
1955
1956 self.message_feedback.insert(message_id, feedback);
1957
1958 cx.notify();
1959
1960 let message_content = self
1961 .message(message_id)
1962 .map(|msg| msg.to_string())
1963 .unwrap_or_default();
1964
1965 cx.background_spawn(async move {
1966 let final_project_snapshot = final_project_snapshot.await;
1967 let serialized_thread = serialized_thread.await?;
1968 let thread_data =
1969 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1970
1971 let rating = match feedback {
1972 ThreadFeedback::Positive => "positive",
1973 ThreadFeedback::Negative => "negative",
1974 };
1975 telemetry::event!(
1976 "Assistant Thread Rated",
1977 rating,
1978 thread_id,
1979 enabled_tool_names,
1980 message_id = message_id.0,
1981 message_content,
1982 thread_data,
1983 final_project_snapshot
1984 );
1985 client.telemetry().flush_events().await;
1986
1987 Ok(())
1988 })
1989 }
1990
    /// Records whole-thread feedback and reports it to telemetry.
    ///
    /// When the thread contains an assistant message, the feedback is
    /// attributed to the most recent one via `report_message_feedback`;
    /// otherwise the feedback is stored on the thread itself and reported
    /// directly.
    // NOTE(review): the else-branch duplicates most of the telemetry logic in
    // `report_message_feedback`; consider unifying them.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Best-effort: degrade to a null payload if serialization to
                // JSON fails, rather than dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2036
2037 /// Create a snapshot of the current project state including git information and unsaved buffers.
2038 fn project_snapshot(
2039 project: Entity<Project>,
2040 cx: &mut Context<Self>,
2041 ) -> Task<Arc<ProjectSnapshot>> {
2042 let git_store = project.read(cx).git_store().clone();
2043 let worktree_snapshots: Vec<_> = project
2044 .read(cx)
2045 .visible_worktrees(cx)
2046 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2047 .collect();
2048
2049 cx.spawn(async move |_, cx| {
2050 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2051
2052 let mut unsaved_buffers = Vec::new();
2053 cx.update(|app_cx| {
2054 let buffer_store = project.read(app_cx).buffer_store();
2055 for buffer_handle in buffer_store.read(app_cx).buffers() {
2056 let buffer = buffer_handle.read(app_cx);
2057 if buffer.is_dirty() {
2058 if let Some(file) = buffer.file() {
2059 let path = file.path().to_string_lossy().to_string();
2060 unsaved_buffers.push(path);
2061 }
2062 }
2063 }
2064 })
2065 .ok();
2066
2067 Arc::new(ProjectSnapshot {
2068 worktree_snapshots,
2069 unsaved_buffer_paths: unsaved_buffers,
2070 timestamp: Utc::now(),
2071 })
2072 })
2073 }
2074
2075 fn worktree_snapshot(
2076 worktree: Entity<project::Worktree>,
2077 git_store: Entity<GitStore>,
2078 cx: &App,
2079 ) -> Task<WorktreeSnapshot> {
2080 cx.spawn(async move |cx| {
2081 // Get worktree path and snapshot
2082 let worktree_info = cx.update(|app_cx| {
2083 let worktree = worktree.read(app_cx);
2084 let path = worktree.abs_path().to_string_lossy().to_string();
2085 let snapshot = worktree.snapshot();
2086 (path, snapshot)
2087 });
2088
2089 let Ok((worktree_path, _snapshot)) = worktree_info else {
2090 return WorktreeSnapshot {
2091 worktree_path: String::new(),
2092 git_state: None,
2093 };
2094 };
2095
2096 let git_state = git_store
2097 .update(cx, |git_store, cx| {
2098 git_store
2099 .repositories()
2100 .values()
2101 .find(|repo| {
2102 repo.read(cx)
2103 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2104 .is_some()
2105 })
2106 .cloned()
2107 })
2108 .ok()
2109 .flatten()
2110 .map(|repo| {
2111 repo.update(cx, |repo, _| {
2112 let current_branch =
2113 repo.branch.as_ref().map(|branch| branch.name.to_string());
2114 repo.send_job(None, |state, _| async move {
2115 let RepositoryState::Local { backend, .. } = state else {
2116 return GitState {
2117 remote_url: None,
2118 head_sha: None,
2119 current_branch,
2120 diff: None,
2121 };
2122 };
2123
2124 let remote_url = backend.remote_url("origin");
2125 let head_sha = backend.head_sha().await;
2126 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2127
2128 GitState {
2129 remote_url,
2130 head_sha,
2131 current_branch,
2132 diff,
2133 }
2134 })
2135 })
2136 });
2137
2138 let git_state = match git_state {
2139 Some(git_state) => match git_state.ok() {
2140 Some(git_state) => git_state.await.ok(),
2141 None => None,
2142 },
2143 None => None,
2144 };
2145
2146 WorktreeSnapshot {
2147 worktree_path,
2148 git_state,
2149 }
2150 })
2151 }
2152
    /// Renders the entire thread (summary, messages, attached context, tool
    /// calls and tool results) as a Markdown document.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        if let Some(summary) = self.summary() {
            writeln!(markdown, "# {summary}\n")?;
        };

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Assistant",
                    Role::System => "System",
                }
            )?;

            // Context (files, symbols, etc.) that was attached to the message.
            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            // Images can't be rendered inline; just note how many there were.
            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    // Redacted thinking is intentionally omitted from output.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                writeln!(markdown, "{}", tool_result.content)?;
            }
        }

        // `markdown` was produced entirely via `write!`, so it is valid
        // UTF-8; `from_utf8_lossy` is just a convenient infallible conversion.
        Ok(String::from_utf8_lossy(&markdown).to_string())
    }
2221
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept
    /// (accepted) in the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2232
    /// Marks every agent edit tracked by the action log as kept (accepted).
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2237
    /// Rejects (reverts) the agent's edits within `buffer_ranges` of `buffer`,
    /// returning the task that performs the revert.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2248
    /// The log tracking this thread's agent edits and buffer reads.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2252
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2256
2257 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2258 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2259 return;
2260 }
2261
2262 let now = Instant::now();
2263 if let Some(last) = self.last_auto_capture_at {
2264 if now.duration_since(last).as_secs() < 10 {
2265 return;
2266 }
2267 }
2268
2269 self.last_auto_capture_at = Some(now);
2270
2271 let thread_id = self.id().clone();
2272 let github_login = self
2273 .project
2274 .read(cx)
2275 .user_store()
2276 .read(cx)
2277 .current_user()
2278 .map(|user| user.github_login.clone());
2279 let client = self.project.read(cx).client().clone();
2280 let serialize_task = self.serialize(cx);
2281
2282 cx.background_executor()
2283 .spawn(async move {
2284 if let Ok(serialized_thread) = serialize_task.await {
2285 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2286 telemetry::event!(
2287 "Agent Thread Auto-Captured",
2288 thread_id = thread_id.to_string(),
2289 thread_data = thread_data,
2290 auto_capture_reason = "tracked_user",
2291 github_login = github_login
2292 );
2293
2294 client.telemetry().flush_events().await;
2295 }
2296 }
2297 })
2298 .detach();
2299 }
2300
    /// Returns the running total of token usage recorded for this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2304
2305 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2306 let Some(model) = self.configured_model.as_ref() else {
2307 return TotalTokenUsage::default();
2308 };
2309
2310 let max = model.model.max_token_count();
2311
2312 let index = self
2313 .messages
2314 .iter()
2315 .position(|msg| msg.id == message_id)
2316 .unwrap_or(0);
2317
2318 if index == 0 {
2319 return TotalTokenUsage { total: 0, max };
2320 }
2321
2322 let token_usage = &self
2323 .request_token_usage
2324 .get(index - 1)
2325 .cloned()
2326 .unwrap_or_default();
2327
2328 TotalTokenUsage {
2329 total: token_usage.total_tokens() as usize,
2330 max,
2331 }
2332 }
2333
2334 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2335 let model = self.configured_model.as_ref()?;
2336
2337 let max = model.model.max_token_count();
2338
2339 if let Some(exceeded_error) = &self.exceeded_window_error {
2340 if model.model.id() == exceeded_error.model_id {
2341 return Some(TotalTokenUsage {
2342 total: exceeded_error.token_count,
2343 max,
2344 });
2345 }
2346 }
2347
2348 let total = self
2349 .token_usage_at_last_message()
2350 .unwrap_or_default()
2351 .total_tokens() as usize;
2352
2353 Some(TotalTokenUsage { total, max })
2354 }
2355
2356 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2357 self.request_token_usage
2358 .get(self.messages.len().saturating_sub(1))
2359 .or_else(|| self.request_token_usage.last())
2360 .cloned()
2361 }
2362
2363 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2364 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2365 self.request_token_usage
2366 .resize(self.messages.len(), placeholder);
2367
2368 if let Some(last) = self.request_token_usage.last_mut() {
2369 *last = token_usage;
2370 }
2371 }
2372
2373 pub fn deny_tool_use(
2374 &mut self,
2375 tool_use_id: LanguageModelToolUseId,
2376 tool_name: Arc<str>,
2377 window: Option<AnyWindowHandle>,
2378 cx: &mut Context<Self>,
2379 ) {
2380 let err = Err(anyhow::anyhow!(
2381 "Permission to run tool action denied by user"
2382 ));
2383
2384 self.tool_use.insert_tool_output(
2385 tool_use_id.clone(),
2386 tool_name,
2387 err,
2388 self.configured_model.as_ref(),
2389 );
2390 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2391 }
2392}
2393
/// User-facing failures that can occur while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account must provide payment before further requests are allowed.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model-request limit has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a short header and a detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2408
/// Events emitted by a thread while a completion streams and tools execute.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be displayed to the user.
    ShowError(ThreadError),
    /// Updated request-usage information was received.
    UsageUpdated(RequestUsage),
    /// A completion event was streamed from the model.
    StreamedCompletion,
    /// A chunk of text was received from the model.
    ReceivedTextChunk,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with parsed JSON input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Any in-progress editing should be canceled.
    CancelEditing,
}
2446
// Wires `Thread` into gpui's event system so observers can subscribe to
// `ThreadEvent`s.
impl EventEmitter<ThreadEvent> for Thread {}
2448
/// An in-flight completion request; the task handle is retained so the
/// underlying work keeps running.
struct PendingCompletion {
    // Monotonic id used to tell concurrent completions apart.
    id: usize,
    _task: Task<()>,
}
2453
2454#[cfg(test)]
2455mod tests {
2456 use super::*;
2457 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2458 use assistant_settings::AssistantSettings;
2459 use assistant_tool::ToolRegistry;
2460 use context_server::ContextServerSettings;
2461 use editor::EditorSettings;
2462 use gpui::TestAppContext;
2463 use language_model::fake_provider::FakeLanguageModel;
2464 use project::{FakeFs, Project};
2465 use prompt_store::PromptBuilder;
2466 use serde_json::json;
2467 use settings::{Settings, SettingsStore};
2468 use std::sync::Arc;
2469 use theme::ThemeSettings;
2470 use util::path;
2471 use workspace::Workspace;
2472
    // Verifies that a user message created with file context stores the
    // rendered <context> block on the message, and that the resulting
    // completion request embeds the context text before the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Two request messages: an initial one plus our user message at index 1.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2542
    // Ensures `new_context_for_thread` returns only context not already
    // attached to an earlier message, and that removing context is reflected
    // when recomputing relative to a specific message id.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Recompute "new" context relative to message 2: context attached at
        // or after message 2 (files 2, 3, 4) counts as new.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2698
    // Verifies that messages inserted with no context have empty context text
    // and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2774
    // Verifies that editing a context buffer after it was attached causes the
    // next completion request to end with a "stale files" notification
    // message, and that no such message exists before the edit.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at offset 1 so the buffer differs from what was read.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2860
    // Registers the settings stores and globals the thread tests depend on.
    // Call once per test before constructing projects/threads.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
2878
2879 // Helper to create a test project with test files
2880 async fn create_test_project(
2881 cx: &mut TestAppContext,
2882 files: serde_json::Value,
2883 ) -> Entity<Project> {
2884 let fs = FakeFs::new(cx.executor());
2885 fs.insert_tree(path!("/test"), files).await;
2886 Project::test(fs, [path!("/test").as_ref()], cx).await
2887 }
2888
2889 async fn setup_test_environment(
2890 cx: &mut TestAppContext,
2891 project: Entity<Project>,
2892 ) -> (
2893 Entity<Workspace>,
2894 Entity<ThreadStore>,
2895 Entity<Thread>,
2896 Entity<ContextStore>,
2897 Arc<dyn LanguageModel>,
2898 ) {
2899 let (workspace, cx) =
2900 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2901
2902 let thread_store = cx
2903 .update(|_, cx| {
2904 ThreadStore::load(
2905 project.clone(),
2906 cx.new(|_| ToolWorkingSet::default()),
2907 None,
2908 Arc::new(PromptBuilder::new(None).unwrap()),
2909 cx,
2910 )
2911 })
2912 .await
2913 .unwrap();
2914
2915 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2916 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2917
2918 let model = FakeLanguageModel::default();
2919 let model: Arc<dyn LanguageModel> = Arc::new(model);
2920
2921 (workspace, thread_store, thread, context_store, model)
2922 }
2923
2924 async fn add_file_to_context(
2925 project: &Entity<Project>,
2926 context_store: &Entity<ContextStore>,
2927 path: &str,
2928 cx: &mut TestAppContext,
2929 ) -> Result<Entity<language::Buffer>> {
2930 let buffer_path = project
2931 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2932 .unwrap();
2933
2934 let buffer = project
2935 .update(cx, |project, cx| {
2936 project.open_buffer(buffer_path.clone(), cx)
2937 })
2938 .await
2939 .unwrap();
2940
2941 context_store.update(cx, |context_store, cx| {
2942 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
2943 });
2944
2945 Ok(buffer)
2946 }
2947}