1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
18 WeakEntity,
19};
20use language_model::{
21 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
22 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
23 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
24 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
25 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
26 TokenUsage,
27};
28use postage::stream::Stream as _;
29use project::Project;
30use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
31use prompt_store::{ModelContext, PromptBuilder};
32use proto::Plan;
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use settings::Settings;
36use thiserror::Error;
37use util::{ResultExt as _, TryFutureExt as _, post_inc};
38use uuid::Uuid;
39use zed_llm_client::CompletionMode;
40
41use crate::ThreadStore;
42use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
43use crate::thread_store::{
44 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
45 SerializedToolUse, SharedProjectContext,
46};
47use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
48
/// Unique identifier for a [`Thread`], backed by a v4 UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Generates a fresh random thread ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    /// Reconstructs an ID from its string form (e.g. when loading a stored thread).
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
71
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Generates a fresh random (v4 UUID) prompt ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
89
/// Monotonically increasing, thread-local identifier for a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID, then advances the counter by one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
98
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered content spans; adjacent same-kind spans are merged by the push_* methods.
    pub segments: Vec<MessageSegment>,
    // Context (files, etc.) that was attached when this message was sent.
    pub loaded_context: LoadedContext,
}
107
108impl Message {
109 /// Returns whether the message contains any meaningful text that should be displayed
110 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
111 pub fn should_display_content(&self) -> bool {
112 self.segments.iter().all(|segment| segment.should_display())
113 }
114
115 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
116 if let Some(MessageSegment::Thinking {
117 text: segment,
118 signature: current_signature,
119 }) = self.segments.last_mut()
120 {
121 if let Some(signature) = signature {
122 *current_signature = Some(signature);
123 }
124 segment.push_str(text);
125 } else {
126 self.segments.push(MessageSegment::Thinking {
127 text: text.to_string(),
128 signature,
129 });
130 }
131 }
132
133 pub fn push_text(&mut self, text: &str) {
134 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Text(text.to_string()));
138 }
139 }
140
141 pub fn to_string(&self) -> String {
142 let mut result = String::new();
143
144 if !self.loaded_context.text.is_empty() {
145 result.push_str(&self.loaded_context.text);
146 }
147
148 for segment in &self.segments {
149 match segment {
150 MessageSegment::Text(text) => result.push_str(text),
151 MessageSegment::Thinking { text, .. } => {
152 result.push_str("<think>\n");
153 result.push_str(text);
154 result.push_str("\n</think>");
155 }
156 MessageSegment::RedactedThinking(_) => {}
157 }
158 }
159
160 result
161 }
162}
163
/// One typed span of a message body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Extended "thinking" output, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Thinking content the provider redacted; opaque bytes, never displayed.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries content worth rendering.
    ///
    /// Fix: this previously returned `text.is_empty()`, inverting the check so
    /// that empty segments were reported as displayable and non-empty ones as
    /// hidden — the opposite of what `Message::should_display_content` documents.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking has no human-readable content.
            Self::RedactedThinking(_) => false,
        }
    }
}
183
/// Point-in-time capture of the project's worktrees, taken when a thread starts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers with unsaved edits at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree and its git status, if any.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Git repository state captured for a worktree snapshot.
/// All fields are best-effort; each is `None` when unavailable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
204
/// A git checkpoint tied to the message it was captured before.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or an individual message (thumbs up/down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// Status of the most recent checkpoint-restore operation.
pub enum LastRestoreCheckpoint {
    /// Restore is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// Restore failed; `error` is a human-readable description.
    Error {
        message_id: MessageId,
        error: String,
    },
}

impl LastRestoreCheckpoint {
    /// The message whose checkpoint was being restored, regardless of outcome.
    pub fn message_id(&self) -> MessageId {
        match self {
            LastRestoreCheckpoint::Pending { message_id } => *message_id,
            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
        }
    }
}
235
/// Lifecycle of the thread's long-form summary, watched via a postage channel.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// Generation in progress, up to and including `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// Finished summary covering messages up to `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}

impl DetailedSummaryState {
    /// The summary text, if generation has completed.
    fn text(&self) -> Option<SharedString> {
        if let Self::Generated { text, .. } = self {
            Some(text.clone())
        } else {
            None
        }
    }
}
258
/// Running token total for a thread, plus the model's context-window size.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window limit; reaching it means the window is exceeded.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage against the context window.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` env var. Fix: an unparsable value now
    /// falls back to the default instead of panicking via `.unwrap()`.
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // Note: when `max == 0`, `total >= max` holds and we report Exceeded,
        // which also avoids the 0/0 division below.
        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy with `tokens` added to the running total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Coarse classification of token usage relative to the context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
299
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // `None` until a summary has been generated; `set_summary` is a no-op before then.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    // Only forwarded to requests when the model supports max mode.
    completion_mode: Option<CompletionMode>,
    // Kept in ascending `MessageId` order; `Thread::message` binary-searches it.
    messages: Vec<Message>,
    next_message_id: MessageId,
    // Regenerated by `advance_prompt_id` on each user submission.
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured with the latest user message; finalized or discarded
    // once generation settles.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Optional hook invoked with each request and its streamed events (for tests/eval).
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Turns left before `send_to_model` stops sending; `u32::MAX` means unlimited.
    remaining_turns: u32,
}
336
/// Record of the most recent context-window overflow, kept for display/serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
344
345impl Thread {
    /// Creates an empty thread with a fresh ID and default state.
    ///
    /// Kicks off an initial project snapshot in the background; the result is
    /// shared so serialization can await it later.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: None,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project state as of thread creation, off the hot path.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
393
    /// Reconstructs a thread from its serialized form.
    ///
    /// Message IDs continue from one past the last stored message. Transient
    /// state (pending completions, checkpoints, feedback) is reset; token usage
    /// and the detailed-summary state are restored from the serialized data.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        // Tool uses/results are rebuilt from the stored messages.
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the rendered context text survives serialization;
                    // the original context handles and images do not.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
473
    /// Installs a hook that observes each request and its streamed completion events.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// True when the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID; called when the user submits a new message.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if one hasn't been produced yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title shown before a summary is generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
515
516 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
517 let Some(current_summary) = &self.summary else {
518 // Don't allow setting summary until generated
519 return;
520 };
521
522 let mut new_summary = new_summary.into();
523
524 if new_summary.is_empty() {
525 new_summary = Self::DEFAULT_SUMMARY;
526 }
527
528 if current_summary != &new_summary {
529 self.summary = Some(new_summary);
530 cx.emit(ThreadEvent::SummaryChanged);
531 }
532 }
533
    /// Requested completion mode; only applied when the model supports max mode.
    pub fn completion_mode(&self) -> Option<CompletionMode> {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: Option<CompletionMode>) {
        self.completion_mode = mode;
    }
541
542 pub fn message(&self, id: MessageId) -> Option<&Message> {
543 let index = self
544 .messages
545 .binary_search_by(|message| message.id.cmp(&id))
546 .ok()?;
547
548 self.messages.get(index)
549 }
550
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is streaming or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given ID, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses awaiting user confirmation before they may run.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for a message, if one was captured.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
584
    /// Restores the project to a previously captured git checkpoint.
    ///
    /// Marks the restore as pending (for UI) up front; on success, truncates the
    /// thread back to the checkpoint's message, on failure records the error.
    /// Either way the pending checkpoint is cleared.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the restored-past messages only once git has succeeded.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
620
    /// Decides whether the checkpoint taken with the last user message is worth
    /// keeping, once generation has finished.
    ///
    /// Compares it against a fresh checkpoint: if the project changed during the
    /// turn the pending checkpoint is inserted (so the user can restore), and it
    /// is also kept on comparison failure, erring on the side of keeping it.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still streaming or running tools; try again later.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint — keep the pending one anyway.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
659
660 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
661 self.checkpoints_by_message
662 .insert(checkpoint.message_id, checkpoint);
663 cx.emit(ThreadEvent::CheckpointChanged);
664 cx.notify();
665 }
666
667 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
668 self.last_restore_checkpoint.as_ref()
669 }
670
    /// Removes `message_id` and every message after it, along with their
    /// checkpoints. No-op if the ID is not found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        // Drain from the matched message onward (inclusive), dropping any
        // checkpoints tied to the removed messages.
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
684
685 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
686 self.messages
687 .iter()
688 .find(|message| message.id == id)
689 .into_iter()
690 .flat_map(|message| message.loaded_context.contexts.iter())
691 }
692
693 pub fn is_turn_end(&self, ix: usize) -> bool {
694 if self.messages.is_empty() {
695 return false;
696 }
697
698 if !self.is_generating() && ix == self.messages.len() - 1 {
699 return true;
700 }
701
702 let Some(message) = self.messages.get(ix) else {
703 return false;
704 };
705
706 if message.role != Role::Assistant {
707 return false;
708 }
709
710 self.messages
711 .get(ix + 1)
712 .and_then(|message| {
713 self.message(message.id)
714 .map(|next_message| next_message.role == Role::User)
715 })
716 .unwrap_or(false)
717 }
718
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses issued by the given (assistant) message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Results of tools the given assistant message invoked.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output text of a finished tool use, if a result exists.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool result, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
751
752 /// Return tools that are both enabled and supported by the model
753 pub fn available_tools(
754 &self,
755 cx: &App,
756 model: Arc<dyn LanguageModel>,
757 ) -> Vec<LanguageModelRequestTool> {
758 if model.supports_tools() {
759 self.tools()
760 .read(cx)
761 .enabled_tools(cx)
762 .into_iter()
763 .filter_map(|tool| {
764 // Skip tools that cannot be supported
765 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
766 Some(LanguageModelRequestTool {
767 name: tool.name(),
768 description: tool.description(),
769 input_schema,
770 })
771 })
772 .collect()
773 } else {
774 Vec::default()
775 }
776 }
777
    /// Appends a user message with its loaded context.
    ///
    /// Buffers referenced by the context are registered with the action log,
    /// and, when provided, a git checkpoint is parked as pending until the turn
    /// completes (see `finalize_pending_checkpoint`).
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
811
    /// Appends an assistant message with no attached context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
    }

    /// Appends a message with the next sequential ID, bumps `updated_at`, and
    /// emits [`ThreadEvent::MessageAdded`]. Returns the new message's ID.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
838
839 pub fn edit_message(
840 &mut self,
841 id: MessageId,
842 new_role: Role,
843 new_segments: Vec<MessageSegment>,
844 cx: &mut Context<Self>,
845 ) -> bool {
846 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
847 return false;
848 };
849 message.role = new_role;
850 message.segments = new_segments;
851 self.touch_updated_at();
852 cx.emit(ThreadEvent::MessageEdited(id));
853 true
854 }
855
856 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
857 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
858 return false;
859 };
860 self.messages.remove(index);
861 self.touch_updated_at();
862 cx.emit(ThreadEvent::MessageDeleted(id));
863 true
864 }
865
866 /// Returns the representation of this [`Thread`] in a textual form.
867 ///
868 /// This is the representation we use when attaching a thread as context to another thread.
869 pub fn text(&self) -> String {
870 let mut text = String::new();
871
872 for message in &self.messages {
873 text.push_str(match message.role {
874 language_model::Role::User => "User:",
875 language_model::Role::Assistant => "Assistant:",
876 language_model::Role::System => "System:",
877 });
878 text.push('\n');
879
880 for segment in &message.segments {
881 match segment {
882 MessageSegment::Text(content) => text.push_str(content),
883 MessageSegment::Thinking { text: content, .. } => {
884 text.push_str(&format!("<think>{}</think>", content))
885 }
886 MessageSegment::RedactedThinking(_) => {}
887 }
888 }
889 text.push('\n');
890 }
891
892 text
893 }
894
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot first, then captures messages with
    /// their tool uses/results, token usage, and summary state.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the rendered context text is persisted (see `deserialize`).
                        context: message.loaded_context.text.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
958
    /// Turns left before `send_to_model` stops sending (`u32::MAX` = unlimited).
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    /// Builds a completion request from the current thread state and streams it
    /// to the model. Silently does nothing once the turn budget is exhausted.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }

    /// Whether any tool results were produced after the most recent user message.
    pub fn used_tools_since_last_user_message(&self) -> bool {
        // Walk backwards; stop at the first user message.
        for message in self.messages.iter().rev() {
            if self.tool_use.message_has_tool_results(message.id) {
                return true;
            } else if message.role == Role::User {
                return false;
            }
        }

        false
    }
995
    /// Assembles the full [`LanguageModelRequest`] for this thread: system
    /// prompt, every message with its context and tool uses/results, available
    /// tools, cache markers, and the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // The system prompt advertises the tool names the model may call.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        // The system prompt is stable across turns, so cache it.
                        cache: true,
                    });
                }
            }
        } else {
            // The request is still sent without a system prompt in this case.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context text/images come before the message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results are delivered as a separate message following the
            // assistant message that issued the tool uses.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            self.completion_mode
        } else {
            None
        };

        request
    }
1113
1114 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1115 let mut request = LanguageModelRequest {
1116 thread_id: None,
1117 prompt_id: None,
1118 mode: None,
1119 messages: vec![],
1120 tools: Vec::new(),
1121 stop: Vec::new(),
1122 temperature: None,
1123 };
1124
1125 for message in &self.messages {
1126 let mut request_message = LanguageModelRequestMessage {
1127 role: message.role,
1128 content: Vec::new(),
1129 cache: false,
1130 };
1131
1132 for segment in &message.segments {
1133 match segment {
1134 MessageSegment::Text(text) => request_message
1135 .content
1136 .push(MessageContent::Text(text.clone())),
1137 MessageSegment::Thinking { .. } => {}
1138 MessageSegment::RedactedThinking(_) => {}
1139 }
1140 }
1141
1142 if request_message.content.is_empty() {
1143 continue;
1144 }
1145
1146 request.messages.push(request_message);
1147 }
1148
1149 request.messages.push(LanguageModelRequestMessage {
1150 role: Role::User,
1151 content: vec![MessageContent::Text(added_user_message)],
1152 cache: false,
1153 });
1154
1155 request
1156 }
1157
1158 fn attached_tracked_files_state(
1159 &self,
1160 messages: &mut Vec<LanguageModelRequestMessage>,
1161 cx: &App,
1162 ) {
1163 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1164
1165 let mut stale_message = String::new();
1166
1167 let action_log = self.action_log.read(cx);
1168
1169 for stale_file in action_log.stale_buffers(cx) {
1170 let Some(file) = stale_file.read(cx).file() else {
1171 continue;
1172 };
1173
1174 if stale_message.is_empty() {
1175 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1176 }
1177
1178 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1179 }
1180
1181 let mut content = Vec::with_capacity(2);
1182
1183 if !stale_message.is_empty() {
1184 content.push(stale_message.into());
1185 }
1186
1187 if !content.is_empty() {
1188 let context_message = LanguageModelRequestMessage {
1189 role: Role::User,
1190 content,
1191 cache: false,
1192 };
1193
1194 messages.push(context_message);
1195 }
1196 }
1197
    /// Streams a completion from `model` for `request`, mutating the thread as
    /// events arrive: assistant messages are inserted/extended, tool uses are
    /// recorded, token usage is tracked, and UI events are emitted. On stop,
    /// pending tools are dispatched (for `StopReason::ToolUse`) or errors are
    /// surfaced as `ThreadEvent::ShowError`s. The spawned task is registered
    /// in `self.pending_completions` so it can be canceled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is installed (used by tests/eval), capture
        // the outgoing request and every response event for later replay.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot usage before streaming so the delta can be reported to
            // telemetry when the completion finishes.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                // Id of the assistant message this stream is currently
                // appending to, if one has been created yet.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        // Invalid tool-input JSON is recoverable (reported back
                        // to the model as a tool error); other errors abort.
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates are cumulative per request, so
                                // only the delta since the previous update is
                                // added to the thread-wide running total.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses must hang off an assistant message;
                                // create an empty one if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially-streamed inputs are re-emitted
                                // so the UI can render them incrementally.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other foreground work can run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known billing/limit errors to dedicated UI
                            // states; everything else becomes a generic message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report this completion's token delta to telemetry.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1496
    /// Generates a short title for the thread using the configured summary
    /// model, storing it in `self.summary` and emitting `SummaryGenerated`.
    ///
    /// Bails out silently when no summary model is configured or its provider
    /// isn't authenticated. Replaces any in-flight summarization task.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                // Forward usage reported with the response so the UI can
                // update its spend indicators.
                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    // Titles are single-line: keep only the first line.
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // An empty response leaves any existing summary untouched.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1551
    /// Kicks off generation of a detailed Markdown summary of the thread,
    /// unless one is already generated (or generating) for the current last
    /// message. Progress is published through `detailed_summary_tx` and the
    /// finished thread is saved via `thread_store` so the summary persists.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        // Mark as generating before spawning so readers see the transition.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // On failure, reset state so a later call can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1641
    /// Waits until the detailed-summary state settles, returning the summary
    /// text once generated, or falling back to the thread's raw text if
    /// generation is not underway. Returns `None` if the thread entity or its
    /// state channel is dropped while waiting.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still streaming — keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1659
1660 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1661 self.detailed_summary_rx
1662 .borrow()
1663 .text()
1664 .unwrap_or_else(|| self.text().into())
1665 }
1666
    /// Whether a detailed-summary generation task is currently in flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1673
1674 pub fn use_pending_tools(
1675 &mut self,
1676 window: Option<AnyWindowHandle>,
1677 cx: &mut Context<Self>,
1678 model: Arc<dyn LanguageModel>,
1679 ) -> Vec<PendingToolUse> {
1680 self.auto_capture_telemetry(cx);
1681 let request = self.to_completion_request(model, cx);
1682 let messages = Arc::new(request.messages);
1683 let pending_tool_uses = self
1684 .tool_use
1685 .pending_tool_uses()
1686 .into_iter()
1687 .filter(|tool_use| tool_use.status.is_idle())
1688 .cloned()
1689 .collect::<Vec<_>>();
1690
1691 for tool_use in pending_tool_uses.iter() {
1692 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1693 if tool.needs_confirmation(&tool_use.input, cx)
1694 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1695 {
1696 self.tool_use.confirm_tool_use(
1697 tool_use.id.clone(),
1698 tool_use.ui_text.clone(),
1699 tool_use.input.clone(),
1700 messages.clone(),
1701 tool,
1702 );
1703 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1704 } else {
1705 self.run_tool(
1706 tool_use.id.clone(),
1707 tool_use.ui_text.clone(),
1708 tool_use.input.clone(),
1709 &messages,
1710 tool,
1711 window,
1712 cx,
1713 );
1714 }
1715 }
1716 }
1717
1718 pending_tool_uses
1719 }
1720
    /// Handles a tool call whose input JSON failed to parse: records a
    /// parse-error result for the tool use (so the model sees the failure on
    /// its next turn), emits `InvalidToolInput` for the UI, and finishes the
    /// tool as non-canceled.
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            cx,
        );
        // Prefer the tool's own UI label; fall back to a generic one if the
        // tool use was somehow never registered as pending.
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
1755
1756 pub fn run_tool(
1757 &mut self,
1758 tool_use_id: LanguageModelToolUseId,
1759 ui_text: impl Into<SharedString>,
1760 input: serde_json::Value,
1761 messages: &[LanguageModelRequestMessage],
1762 tool: Arc<dyn Tool>,
1763 window: Option<AnyWindowHandle>,
1764 cx: &mut Context<Thread>,
1765 ) {
1766 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1767 self.tool_use
1768 .run_pending_tool(tool_use_id, ui_text.into(), task);
1769 }
1770
    /// Runs `tool` asynchronously (or synthesizes an error if the tool is
    /// disabled), stores any UI card it produces, and on completion records
    /// the output and calls `tool_finished`. Returns the spawned task.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Disabled tools still produce a result so the model learns the tool
        // was unavailable rather than the call silently vanishing.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // If the thread was dropped meanwhile, the result is discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
1819
1820 fn tool_finished(
1821 &mut self,
1822 tool_use_id: LanguageModelToolUseId,
1823 pending_tool_use: Option<PendingToolUse>,
1824 canceled: bool,
1825 window: Option<AnyWindowHandle>,
1826 cx: &mut Context<Self>,
1827 ) {
1828 if self.all_tools_finished() {
1829 let model_registry = LanguageModelRegistry::read_global(cx);
1830 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1831 if !canceled {
1832 self.send_to_model(model, window, cx);
1833 }
1834 self.auto_capture_telemetry(cx);
1835 }
1836 }
1837
1838 cx.emit(ThreadEvent::ToolFinished {
1839 tool_use_id,
1840 pending_tool_use,
1841 });
1842 }
1843
1844 /// Cancels the last pending completion, if there are any pending.
1845 ///
1846 /// Returns whether a completion was canceled.
1847 pub fn cancel_last_completion(
1848 &mut self,
1849 window: Option<AnyWindowHandle>,
1850 cx: &mut Context<Self>,
1851 ) -> bool {
1852 let mut canceled = self.pending_completions.pop().is_some();
1853
1854 for pending_tool_use in self.tool_use.cancel_pending() {
1855 canceled = true;
1856 self.tool_finished(
1857 pending_tool_use.id.clone(),
1858 Some(pending_tool_use),
1859 true,
1860 window,
1861 cx,
1862 );
1863 }
1864
1865 self.finalize_pending_checkpoint(cx);
1866 canceled
1867 }
1868
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
1876
    /// The thread-level feedback rating, if the user has provided one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1880
    /// The feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1884
    /// Records the user's feedback for a message and reports it to telemetry
    /// along with a serialized copy of the thread and a project snapshot.
    /// Re-submitting the same rating for the same message is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs the foreground context before moving
        // the reporting work to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than
            // dropping the event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
1942
    /// Records thread-level feedback. When the thread contains an assistant
    /// message, the feedback is attributed to the most recent one via
    /// `report_message_feedback`; otherwise a thread-wide rating is stored
    /// and reported to telemetry directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach to: record thread-level feedback.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failures degrade to a null payload rather
                // than dropping the event entirely.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
1988
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; dirty buffer paths are
    /// collected on the foreground thread. Failures while reading buffers are
    /// ignored (best-effort snapshot for telemetry).
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record paths of buffers with unsaved changes.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2026
    /// Captures a snapshot of one worktree: its absolute path plus, when a
    /// local git repository covers it, the repo's remote URL, HEAD sha,
    /// current branch, and HEAD-to-worktree diff. Any failure along the way
    /// degrades to `git_state: None` (or an empty path) rather than erroring.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose root contains this worktree, then ask
            // its backend (via a repo job) for the git details.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can query.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the Option<Result<Task>> produced above; any layer
            // failing yields no git state.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2104
2105 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2106 let mut markdown = Vec::new();
2107
2108 if let Some(summary) = self.summary() {
2109 writeln!(markdown, "# {summary}\n")?;
2110 };
2111
2112 for message in self.messages() {
2113 writeln!(
2114 markdown,
2115 "## {role}\n",
2116 role = match message.role {
2117 Role::User => "User",
2118 Role::Assistant => "Assistant",
2119 Role::System => "System",
2120 }
2121 )?;
2122
2123 if !message.loaded_context.text.is_empty() {
2124 writeln!(markdown, "{}", message.loaded_context.text)?;
2125 }
2126
2127 if !message.loaded_context.images.is_empty() {
2128 writeln!(
2129 markdown,
2130 "\n{} images attached as context.\n",
2131 message.loaded_context.images.len()
2132 )?;
2133 }
2134
2135 for segment in &message.segments {
2136 match segment {
2137 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2138 MessageSegment::Thinking { text, .. } => {
2139 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2140 }
2141 MessageSegment::RedactedThinking(_) => {}
2142 }
2143 }
2144
2145 for tool_use in self.tool_uses_for_message(message.id, cx) {
2146 writeln!(
2147 markdown,
2148 "**Use Tool: {} ({})**",
2149 tool_use.name, tool_use.id
2150 )?;
2151 writeln!(markdown, "```json")?;
2152 writeln!(
2153 markdown,
2154 "{}",
2155 serde_json::to_string_pretty(&tool_use.input)?
2156 )?;
2157 writeln!(markdown, "```")?;
2158 }
2159
2160 for tool_result in self.tool_results_for_message(message.id) {
2161 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2162 if tool_result.is_error {
2163 write!(markdown, " (Error)")?;
2164 }
2165
2166 writeln!(markdown, "**\n")?;
2167 writeln!(markdown, "{}", tool_result.content)?;
2168 }
2169 }
2170
2171 Ok(String::from_utf8_lossy(&markdown).to_string())
2172 }
2173
    /// Marks the agent's edits within `buffer_range` of `buffer` as accepted,
    /// delegating to the thread's `ActionLog`.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2184
    /// Marks every outstanding agent edit as accepted, delegating to the
    /// thread's `ActionLog`.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2189
    /// Reverts the agent's edits within `buffer_ranges` of `buffer`,
    /// delegating to the thread's `ActionLog`. Returns the log's revert task.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2200
    /// The action log tracking this thread's buffer reads and edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2204
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2208
2209 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2210 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2211 return;
2212 }
2213
2214 let now = Instant::now();
2215 if let Some(last) = self.last_auto_capture_at {
2216 if now.duration_since(last).as_secs() < 10 {
2217 return;
2218 }
2219 }
2220
2221 self.last_auto_capture_at = Some(now);
2222
2223 let thread_id = self.id().clone();
2224 let github_login = self
2225 .project
2226 .read(cx)
2227 .user_store()
2228 .read(cx)
2229 .current_user()
2230 .map(|user| user.github_login.clone());
2231 let client = self.project.read(cx).client().clone();
2232 let serialize_task = self.serialize(cx);
2233
2234 cx.background_executor()
2235 .spawn(async move {
2236 if let Ok(serialized_thread) = serialize_task.await {
2237 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2238 telemetry::event!(
2239 "Agent Thread Auto-Captured",
2240 thread_id = thread_id.to_string(),
2241 thread_data = thread_data,
2242 auto_capture_reason = "tracked_user",
2243 github_login = github_login
2244 );
2245
2246 client.telemetry().flush_events().await;
2247 }
2248 }
2249 })
2250 .detach();
2251 }
2252
    /// Total token usage accumulated across every completion in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2256
2257 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2258 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2259 return TotalTokenUsage::default();
2260 };
2261
2262 let max = model.model.max_token_count();
2263
2264 let index = self
2265 .messages
2266 .iter()
2267 .position(|msg| msg.id == message_id)
2268 .unwrap_or(0);
2269
2270 if index == 0 {
2271 return TotalTokenUsage { total: 0, max };
2272 }
2273
2274 let token_usage = &self
2275 .request_token_usage
2276 .get(index - 1)
2277 .cloned()
2278 .unwrap_or_default();
2279
2280 TotalTokenUsage {
2281 total: token_usage.total_tokens() as usize,
2282 max,
2283 }
2284 }
2285
2286 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2287 let model_registry = LanguageModelRegistry::read_global(cx);
2288 let Some(model) = model_registry.default_model() else {
2289 return TotalTokenUsage::default();
2290 };
2291
2292 let max = model.model.max_token_count();
2293
2294 if let Some(exceeded_error) = &self.exceeded_window_error {
2295 if model.model.id() == exceeded_error.model_id {
2296 return TotalTokenUsage {
2297 total: exceeded_error.token_count,
2298 max,
2299 };
2300 }
2301 }
2302
2303 let total = self
2304 .token_usage_at_last_message()
2305 .unwrap_or_default()
2306 .total_tokens() as usize;
2307
2308 TotalTokenUsage { total, max }
2309 }
2310
2311 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2312 self.request_token_usage
2313 .get(self.messages.len().saturating_sub(1))
2314 .or_else(|| self.request_token_usage.last())
2315 .cloned()
2316 }
2317
2318 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2319 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2320 self.request_token_usage
2321 .resize(self.messages.len(), placeholder);
2322
2323 if let Some(last) = self.request_token_usage.last_mut() {
2324 *last = token_usage;
2325 }
2326 }
2327
2328 pub fn deny_tool_use(
2329 &mut self,
2330 tool_use_id: LanguageModelToolUseId,
2331 tool_name: Arc<str>,
2332 window: Option<AnyWindowHandle>,
2333 cx: &mut Context<Self>,
2334 ) {
2335 let err = Err(anyhow::anyhow!(
2336 "Permission to run tool action denied by user"
2337 ));
2338
2339 self.tool_use
2340 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2341 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2342 }
2343}
2344
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before more completions can be served.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured monthly spending cap has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Generic error carrying a header and message to display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2359
/// Events emitted by a `Thread` while streaming completions and running tools.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    /// Updated request-usage figures were received.
    UsageUpdated(RequestUsage),
    /// A completion event was streamed from the model.
    StreamedCompletion,
    /// A chunk of text was received from the stream.
    ReceivedTextChunk,
    /// Assistant text was streamed for the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant thinking content was streamed for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool use whose input parsed successfully.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that could not be parsed as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Editing should be canceled.
    CancelEditing,
}
2397
// Marks `Thread` as a GPUI emitter of `ThreadEvent`s.
impl EventEmitter<ThreadEvent> for Thread {}
2399
/// An in-flight completion request.
struct PendingCompletion {
    /// Identifier used to tell this completion apart from others.
    id: usize,
    /// Held so the completion task stays alive for this struct's lifetime.
    _task: Task<()>,
}
2404
2405#[cfg(test)]
2406mod tests {
2407 use super::*;
2408 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2409 use assistant_settings::AssistantSettings;
2410 use assistant_tool::ToolRegistry;
2411 use context_server::ContextServerSettings;
2412 use editor::EditorSettings;
2413 use gpui::TestAppContext;
2414 use language_model::fake_provider::FakeLanguageModel;
2415 use project::{FakeFs, Project};
2416 use prompt_store::PromptBuilder;
2417 use serde_json::json;
2418 use settings::{Settings, SettingsStore};
2419 use std::sync::Arc;
2420 use theme::ThemeSettings;
2421 use util::path;
2422 use workspace::Workspace;
2423
    /// A file attached as context should be embedded in the user message's
    /// context text and prepended to the message text in the request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // Load the single attached context item so it can accompany the message.
        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
 println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // One leading message plus our user message; the user message carries
        // the rendered context followed by the typed text.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2493
    /// Each user message should carry only the contexts that were not already
    /// attached to an earlier message in the same thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 user messages plus one leading message
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }
2599
    /// Messages inserted with an empty `ContextLoadResult` should carry no
    /// context text, and requests should contain just the typed message text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // One leading message plus our user message, with no context prefix.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2675
    /// After a buffer that was attached as context is modified, the next
    /// completion request should end with a "stale files" notification
    /// listing the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer so it no longer matches what was last read
        buffer.update(cx, |buffer, cx| {
            // Insert new text at offset 1, making the buffer stale
            buffer.edit(
                [(1..1, "\n println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2761
    /// Registers the global settings stores and registries the tests rely on.
    /// Called at the start of every test, before any entities are created.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
2778
2779 // Helper to create a test project with test files
2780 async fn create_test_project(
2781 cx: &mut TestAppContext,
2782 files: serde_json::Value,
2783 ) -> Entity<Project> {
2784 let fs = FakeFs::new(cx.executor());
2785 fs.insert_tree(path!("/test"), files).await;
2786 Project::test(fs, [path!("/test").as_ref()], cx).await
2787 }
2788
    /// Creates the shared test fixture: a workspace window, a `ThreadStore`
    /// backed by `project`, a fresh thread, an empty context store, and a
    /// fake language model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        // A fake model so tests never hit a real provider.
        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
2823
2824 async fn add_file_to_context(
2825 project: &Entity<Project>,
2826 context_store: &Entity<ContextStore>,
2827 path: &str,
2828 cx: &mut TestAppContext,
2829 ) -> Result<Entity<language::Buffer>> {
2830 let buffer_path = project
2831 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2832 .unwrap();
2833
2834 let buffer = project
2835 .update(cx, |project, cx| {
2836 project.open_buffer(buffer_path.clone(), cx)
2837 })
2838 .await
2839 .unwrap();
2840
2841 context_store.update(cx, |context_store, cx| {
2842 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
2843 });
2844
2845 Ok(buffer)
2846 }
2847}