1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
18};
19use language_model::{
20 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
21 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
22 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
23 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
24 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
25 TokenUsage,
26};
27use project::Project;
28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
29use prompt_store::{ModelContext, PromptBuilder};
30use proto::Plan;
31use schemars::JsonSchema;
32use serde::{Deserialize, Serialize};
33use settings::Settings;
34use thiserror::Error;
35use util::{ResultExt as _, TryFutureExt as _, post_inc};
36use uuid::Uuid;
37use zed_llm_client::CompletionMode;
38
39use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
40use crate::thread_store::{
41 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
42 SerializedToolUse, SharedProjectContext,
43};
44use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
45
/// Unique identifier for a [`Thread`].
///
/// Typically a freshly generated v4 UUID string (see [`ThreadId::new`]), but any
/// string is accepted via `From<&str>`. Backed by `Arc<str>`, so clones are cheap.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
50
51impl ThreadId {
52 pub fn new() -> Self {
53 Self(Uuid::new_v4().to_string().into())
54 }
55}
56
57impl std::fmt::Display for ThreadId {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 write!(f, "{}", self.0)
60 }
61}
62
63impl From<&str> for ThreadId {
64 fn from(value: &str) -> Self {
65 Self(value.into())
66 }
67}
68
69/// The ID of the user prompt that initiated a request.
70///
71/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
72#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
73pub struct PromptId(Arc<str>);
74
75impl PromptId {
76 pub fn new() -> Self {
77 Self(Uuid::new_v4().to_string().into())
78 }
79}
80
81impl std::fmt::Display for PromptId {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 write!(f, "{}", self.0)
84 }
85}
86
/// Sequential identifier for a [`Message`] within a single [`Thread`].
///
/// Assigned from `Thread::next_message_id` via [`MessageId::post_inc`], so ids
/// increase in insertion order.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
95
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    // Context that was loaded and attached when this message was inserted;
    // only its flattened text survives serialization (see `deserialize`).
    pub loaded_context: LoadedContext,
}
104
105impl Message {
106 /// Returns whether the message contains any meaningful text that should be displayed
107 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
108 pub fn should_display_content(&self) -> bool {
109 self.segments.iter().all(|segment| segment.should_display())
110 }
111
112 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
113 if let Some(MessageSegment::Thinking {
114 text: segment,
115 signature: current_signature,
116 }) = self.segments.last_mut()
117 {
118 if let Some(signature) = signature {
119 *current_signature = Some(signature);
120 }
121 segment.push_str(text);
122 } else {
123 self.segments.push(MessageSegment::Thinking {
124 text: text.to_string(),
125 signature,
126 });
127 }
128 }
129
130 pub fn push_text(&mut self, text: &str) {
131 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
132 segment.push_str(text);
133 } else {
134 self.segments.push(MessageSegment::Text(text.to_string()));
135 }
136 }
137
138 pub fn to_string(&self) -> String {
139 let mut result = String::new();
140
141 if !self.loaded_context.text.is_empty() {
142 result.push_str(&self.loaded_context.text);
143 }
144
145 for segment in &self.segments {
146 match segment {
147 MessageSegment::Text(text) => result.push_str(text),
148 MessageSegment::Thinking { text, .. } => {
149 result.push_str("<think>\n");
150 result.push_str(text);
151 result.push_str("\n</think>");
152 }
153 MessageSegment::RedactedThinking(_) => {}
154 }
155 }
156
157 result
158 }
159}
160
/// One piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain message text.
    Text(String),
    /// Model "thinking" output, with an optional provider-supplied signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Thinking content the provider redacted; kept only as opaque bytes.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries user-visible content.
    ///
    /// A text or thinking segment is displayable only when its text is
    /// non-empty; redacted thinking is never displayed. Note: the check was
    /// previously inverted (`text.is_empty()`), which reported exactly the
    /// empty segments as displayable.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
180
/// Point-in-time capture of the project's worktrees and their git state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers that had unsaved edits when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // `None` when no git state was captured for this worktree.
    pub git_state: Option<GitState>,
}

/// Captured git metadata for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    // Textual diff of uncommitted changes, when one was produced.
    pub diff: Option<String>,
}
201
/// Associates a git checkpoint with the user message it was captured for,
/// so the project can be restored to its state at that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or message (thumbs up / thumbs down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
213
214pub enum LastRestoreCheckpoint {
215 Pending {
216 message_id: MessageId,
217 },
218 Error {
219 message_id: MessageId,
220 error: String,
221 },
222}
223
224impl LastRestoreCheckpoint {
225 pub fn message_id(&self) -> MessageId {
226 match self {
227 LastRestoreCheckpoint::Pending { message_id } => *message_id,
228 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
229 }
230 }
231}
232
/// Progress of generating the thread's detailed summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// Generation is in progress; `message_id` anchors the request
    /// (presumably the last message included — TODO confirm at the call site).
    Generating {
        message_id: MessageId,
    },
    /// A summary `text` exists, anchored at `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
245
/// Running token totals for a thread, relative to the model's context limit.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens the context window allows.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage against the context window.
    ///
    /// Usage at or above `max` is `Exceeded`; usage at or above the warning
    /// threshold (a fraction of `max`, default 0.8) is `Warning`; anything else
    /// is `Normal`. In debug builds the threshold can be overridden via the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable; a missing or
    /// unparsable value falls back to the default instead of panicking (the
    /// previous `.parse().unwrap()` aborted on malformed input).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When `max` is 0 the first branch always wins (usize: total >= 0),
        // so the float division below never sees a zero denominator.
        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Coarse classification of token usage relative to the context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
286
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    // Last time this thread was modified; bumped by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    // Short title; `None` until a summary has been generated.
    summary: Option<SharedString>,
    // In-flight summary-generation task, if any.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    completion_mode: Option<CompletionMode>,
    // Kept sorted by id — `message()` binary-searches this.
    messages: Vec<Message>,
    next_message_id: MessageId,
    // Refreshed by `advance_prompt_id` on each new user submission.
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per user message.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    // Counter used to assign ids to pending completions.
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    // Status of the most recent checkpoint restore, if one happened.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint awaiting finalization once the current turn completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Snapshot captured when the thread was created; shared so serialization
    // can await it.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    // Token usage accumulated across all requests in this thread.
    cumulative_token_usage: TokenUsage,
    // Set when the last request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    // Timestamp of the last automatic telemetry capture — presumably used to
    // rate-limit `auto_capture_telemetry`; confirm against that method.
    last_auto_capture_at: Option<Instant>,
    // Optional hook (set via `set_request_callback`) observing each request
    // and the raw streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Turns still allowed; `send_to_model` decrements and bails at 0.
    remaining_turns: u32,
}
321
/// Records that a completion request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
329
330impl Thread {
331 pub fn new(
332 project: Entity<Project>,
333 tools: Entity<ToolWorkingSet>,
334 prompt_builder: Arc<PromptBuilder>,
335 system_prompt: SharedProjectContext,
336 cx: &mut Context<Self>,
337 ) -> Self {
338 Self {
339 id: ThreadId::new(),
340 updated_at: Utc::now(),
341 summary: None,
342 pending_summary: Task::ready(None),
343 detailed_summary_state: DetailedSummaryState::NotGenerated,
344 completion_mode: None,
345 messages: Vec::new(),
346 next_message_id: MessageId(0),
347 last_prompt_id: PromptId::new(),
348 project_context: system_prompt,
349 checkpoints_by_message: HashMap::default(),
350 completion_count: 0,
351 pending_completions: Vec::new(),
352 project: project.clone(),
353 prompt_builder,
354 tools: tools.clone(),
355 last_restore_checkpoint: None,
356 pending_checkpoint: None,
357 tool_use: ToolUseState::new(tools.clone()),
358 action_log: cx.new(|_| ActionLog::new(project.clone())),
359 initial_project_snapshot: {
360 let project_snapshot = Self::project_snapshot(project, cx);
361 cx.foreground_executor()
362 .spawn(async move { Some(project_snapshot.await) })
363 .shared()
364 },
365 request_token_usage: Vec::new(),
366 cumulative_token_usage: TokenUsage::default(),
367 exceeded_window_error: None,
368 feedback: None,
369 message_feedback: HashMap::default(),
370 last_auto_capture_at: None,
371 request_callback: None,
372 remaining_turns: u32::MAX,
373 }
374 }
375
    /// Reconstructs a thread from its serialized form.
    ///
    /// Ephemeral state (pending completions, checkpoints, feedback, the request
    /// callback) is reset to defaults; only persisted data is restored.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            completion_mode: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; structured
                    // contexts and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            // The snapshot was already captured when the thread was first
            // created, so restore it as an immediately ready task.
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
451
    /// Installs a hook that observes every completion request together with the
    /// raw events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified right now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if none exists yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Title used for threads that have no generated summary yet.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    /// Overrides the summary, substituting [`Self::DEFAULT_SUMMARY`] when
    /// `new_summary` is empty. Does nothing until an initial summary exists,
    /// and emits [`ThreadEvent::SummaryChanged`] only on an actual change.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// The detailed summary when one was generated, otherwise the full thread
    /// text.
    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.latest_detailed_summary()
            .unwrap_or_else(|| self.text().into())
    }

    fn latest_detailed_summary(&self) -> Option<SharedString> {
        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
            Some(text.clone())
        } else {
            None
        }
    }

    pub fn completion_mode(&self) -> Option<CompletionMode> {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: Option<CompletionMode>) {
        self.completion_mode = mode;
    }
532
533 pub fn message(&self, id: MessageId) -> Option<&Message> {
534 let index = self
535 .messages
536 .binary_search_by(|message| message.id.cmp(&id))
537 .ok()?;
538
539 self.messages.get(index)
540 }
541
542 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
543 self.messages.iter()
544 }
545
546 pub fn is_generating(&self) -> bool {
547 !self.pending_completions.is_empty() || !self.all_tools_finished()
548 }
549
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds the pending tool use with the given id, if it is still pending.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that require user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
575
    /// Restores the project to the checkpoint's git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failure are surfaced through `last_restore_checkpoint` and
    /// [`ThreadEvent::CheckpointChanged`] events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Flag the restore as in flight immediately so observers can react.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed state so the error can be displayed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way, the previously pending checkpoint is obsolete.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
611
    /// Promotes the pending checkpoint to a permanent one once generation has
    /// finished, but only if the repository actually changed since the
    /// checkpoint was taken.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // The turn isn't over yet; keep the checkpoint pending.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint when the repository changed during
                // the turn; restoring an identical checkpoint would be useless.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If taking the final checkpoint failed, conservatively keep the
            // pending one rather than silently dropping it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
650
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Status of the most recent checkpoint restore, if one happened.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
661
662 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
663 let Some(message_ix) = self
664 .messages
665 .iter()
666 .rposition(|message| message.id == message_id)
667 else {
668 return;
669 };
670 for deleted_message in self.messages.drain(message_ix..) {
671 self.checkpoints_by_message.remove(&deleted_message.id);
672 }
673 cx.notify();
674 }
675
676 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
677 self.messages
678 .iter()
679 .find(|message| message.id == id)
680 .into_iter()
681 .flat_map(|message| message.loaded_context.contexts.iter())
682 }
683
684 pub fn is_turn_end(&self, ix: usize) -> bool {
685 if self.messages.is_empty() {
686 return false;
687 }
688
689 if !self.is_generating() && ix == self.messages.len() - 1 {
690 return true;
691 }
692
693 let Some(message) = self.messages.get(ix) else {
694 return false;
695 };
696
697 if message.role != Role::Assistant {
698 return false;
699 }
700
701 self.messages
702 .get(ix + 1)
703 .and_then(|message| {
704 self.message(message.id)
705 .map(|next_message| next_message.role == Role::User)
706 })
707 .unwrap_or(false)
708 }
709
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses issued by the message with the given id.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced for the given assistant message's tool uses.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a completed tool use.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card registered for a tool use's result, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
742
743 /// Return tools that are both enabled and supported by the model
744 pub fn available_tools(
745 &self,
746 cx: &App,
747 model: Arc<dyn LanguageModel>,
748 ) -> Vec<LanguageModelRequestTool> {
749 if model.supports_tools() {
750 self.tools()
751 .read(cx)
752 .enabled_tools(cx)
753 .into_iter()
754 .filter_map(|tool| {
755 // Skip tools that cannot be supported
756 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
757 Some(LanguageModelRequestTool {
758 name: tool.name(),
759 description: tool.description(),
760 input_schema,
761 })
762 })
763 .collect()
764 } else {
765 Vec::default()
766 }
767 }
768
    /// Appends a user message along with its loaded context, optionally
    /// recording a git checkpoint to associate with it. Returns the new
    /// message's id.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Start tracking any buffers the context references so staleness can
        // be detected later. Guarded so the action log is only updated when
        // there is actually something to track.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            cx,
        );

        // The checkpoint stays pending until the turn completes; see
        // `finalize_pending_checkpoint`.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
802
    /// Appends an assistant message composed of the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
    }

    /// Appends a message with the next sequential id, bumps `updated_at`, and
    /// emits [`ThreadEvent::MessageAdded`]. Returns the new message's id.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
829
830 pub fn edit_message(
831 &mut self,
832 id: MessageId,
833 new_role: Role,
834 new_segments: Vec<MessageSegment>,
835 cx: &mut Context<Self>,
836 ) -> bool {
837 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
838 return false;
839 };
840 message.role = new_role;
841 message.segments = new_segments;
842 self.touch_updated_at();
843 cx.emit(ThreadEvent::MessageEdited(id));
844 true
845 }
846
847 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
848 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
849 return false;
850 };
851 self.messages.remove(index);
852 self.touch_updated_at();
853 cx.emit(ThreadEvent::MessageDeleted(id));
854 true
855 }
856
857 /// Returns the representation of this [`Thread`] in a textual form.
858 ///
859 /// This is the representation we use when attaching a thread as context to another thread.
860 pub fn text(&self) -> String {
861 let mut text = String::new();
862
863 for message in &self.messages {
864 text.push_str(match message.role {
865 language_model::Role::User => "User:",
866 language_model::Role::Assistant => "Assistant:",
867 language_model::Role::System => "System:",
868 });
869 text.push('\n');
870
871 for segment in &message.segments {
872 match segment {
873 MessageSegment::Text(content) => text.push_str(content),
874 MessageSegment::Thinking { text: content, .. } => {
875 text.push_str(&format!("<think>{}</think>", content))
876 }
877 MessageSegment::RedactedThinking(_) => {}
878 }
879 }
880 text.push('\n');
881 }
882
883 text
884 }
885
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the initial snapshot task first so the serialized form
            // includes it when available.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted; see
                        // `deserialize`, which restores it as `LoadedContext.text`.
                        context: message.loaded_context.text.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
949
950 pub fn remaining_turns(&self) -> u32 {
951 self.remaining_turns
952 }
953
954 pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
955 self.remaining_turns = remaining_turns;
956 }
957
958 pub fn send_to_model(
959 &mut self,
960 model: Arc<dyn LanguageModel>,
961 window: Option<AnyWindowHandle>,
962 cx: &mut Context<Self>,
963 ) {
964 if self.remaining_turns == 0 {
965 return;
966 }
967
968 self.remaining_turns -= 1;
969
970 let request = self.to_completion_request(model.clone(), cx);
971
972 self.stream_completion(request, model, window, cx);
973 }
974
975 pub fn used_tools_since_last_user_message(&self) -> bool {
976 for message in self.messages.iter().rev() {
977 if self.tool_use.message_has_tool_results(message.id) {
978 return true;
979 } else if message.role == Role::User {
980 return false;
981 }
982 }
983
984 false
985 }
986
    /// Builds the full [`LanguageModelRequest`] for this thread: system prompt,
    /// message history with attached context and tool uses/results,
    /// stale-file notices, and the tools available to the model.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt needs the project context. If it isn't ready, or
        // prompt generation fails, an error is surfaced to the UI and the
        // request proceeds without a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        // The system prompt is stable across requests, so it is
                        // a good caching anchor.
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context goes first so the message text follows it.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    // Empty text/thinking segments are skipped entirely.
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results are delivered in a separate message following the
            // assistant message that issued the tool uses.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            self.completion_mode
        } else {
            None
        };

        request
    }
1104
1105 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1106 let mut request = LanguageModelRequest {
1107 thread_id: None,
1108 prompt_id: None,
1109 mode: None,
1110 messages: vec![],
1111 tools: Vec::new(),
1112 stop: Vec::new(),
1113 temperature: None,
1114 };
1115
1116 for message in &self.messages {
1117 let mut request_message = LanguageModelRequestMessage {
1118 role: message.role,
1119 content: Vec::new(),
1120 cache: false,
1121 };
1122
1123 for segment in &message.segments {
1124 match segment {
1125 MessageSegment::Text(text) => request_message
1126 .content
1127 .push(MessageContent::Text(text.clone())),
1128 MessageSegment::Thinking { .. } => {}
1129 MessageSegment::RedactedThinking(_) => {}
1130 }
1131 }
1132
1133 if request_message.content.is_empty() {
1134 continue;
1135 }
1136
1137 request.messages.push(request_message);
1138 }
1139
1140 request.messages.push(LanguageModelRequestMessage {
1141 role: Role::User,
1142 content: vec![MessageContent::Text(added_user_message)],
1143 cache: false,
1144 });
1145
1146 request
1147 }
1148
1149 fn attached_tracked_files_state(
1150 &self,
1151 messages: &mut Vec<LanguageModelRequestMessage>,
1152 cx: &App,
1153 ) {
1154 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1155
1156 let mut stale_message = String::new();
1157
1158 let action_log = self.action_log.read(cx);
1159
1160 for stale_file in action_log.stale_buffers(cx) {
1161 let Some(file) = stale_file.read(cx).file() else {
1162 continue;
1163 };
1164
1165 if stale_message.is_empty() {
1166 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1167 }
1168
1169 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1170 }
1171
1172 let mut content = Vec::with_capacity(2);
1173
1174 if !stale_message.is_empty() {
1175 content.push(stale_message.into());
1176 }
1177
1178 if !content.is_empty() {
1179 let context_message = LanguageModelRequestMessage {
1180 role: Role::User,
1181 content,
1182 cache: false,
1183 };
1184
1185 messages.push(context_message);
1186 }
1187 }
1188
    /// Streams a completion from `model` for `request`, incrementally applying
    /// events to this thread: text/thinking chunks are appended to (or start) an
    /// assistant message, tool uses are registered, and token usage is tracked.
    /// On stream end, handles the stop reason (running pending tools on
    /// `StopReason::ToolUse`) or surfaces errors as `ThreadEvent::ShowError`.
    /// The spawned task is retained in `self.pending_completions` so it can be
    /// canceled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record request/response pairs when a callback is installed
        // (used for capturing traffic, e.g. in evals/tests).
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot cumulative usage before streaming so the telemetry event
            // below can report this completion's delta.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                // Set when a StartMessage/Text/Thinking/ToolUse event creates the
                // assistant message for this response.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Malformed tool-input JSON is recoverable: report it as
                            // a failed tool use and keep consuming the stream.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Providers report usage cumulatively per response, so
                                // add only the delta since the last update.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only emit a streaming event for partial input;
                                // completed input is handled by request_tool_use.
                                // NOTE(review): `(&tool_use.input).clone()` clones the
                                // value itself; the extra borrow is redundant.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other tasks (e.g. UI rendering) can run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            // Post-stream handling always runs, for both success and error paths.
            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to specific UI errors; everything
                            // else becomes a generic message with the full error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1487
    /// Asks the configured thread-summary model for a short (3-7 word) title
    /// and stores it in `self.summary`, emitting `SummaryGenerated` when done.
    /// No-ops if no summary model is configured or its provider is not
    /// authenticated. Replaces any in-flight summary task.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            // Inner async block so `?` failures are logged (via `log_err`)
            // instead of aborting anything beyond this task.
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    // Only the first line of each chunk is kept for the title.
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // Keep the previous summary if the model produced nothing.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1542
    /// Generates a detailed Markdown summary of the conversation, storing it in
    /// `detailed_summary_state`. Returns `None` when a summary is already
    /// generated/generating for the latest message, when the thread is empty,
    /// or when no authenticated summary model is available; otherwise returns
    /// the in-flight task.
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, roll the state back to NotGenerated
            // so a later attempt can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        // Mark as generating immediately so concurrent callers short-circuit above.
        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1609
    /// Whether a detailed-summary generation task is currently in flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1616
    /// Dispatches every idle pending tool use: tools that require confirmation
    /// (and aren't globally auto-allowed) are queued for user approval, all
    /// others are run immediately. Returns the snapshot of pending tool uses
    /// that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the full request message history as context.
        let request = self.to_completion_request(model, cx);
        let messages = Arc::new(request.messages);
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            // Tool uses referencing unknown tools are skipped here; the model's
            // request stays pending. TODO confirm how unknown tools are surfaced.
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        window,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }
1663
1664 pub fn receive_invalid_tool_json(
1665 &mut self,
1666 tool_use_id: LanguageModelToolUseId,
1667 tool_name: Arc<str>,
1668 invalid_json: Arc<str>,
1669 error: String,
1670 window: Option<AnyWindowHandle>,
1671 cx: &mut Context<Thread>,
1672 ) {
1673 log::error!("The model returned invalid input JSON: {invalid_json}");
1674
1675 let pending_tool_use = self.tool_use.insert_tool_output(
1676 tool_use_id.clone(),
1677 tool_name,
1678 Err(anyhow!("Error parsing input JSON: {error}")),
1679 cx,
1680 );
1681 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1682 pending_tool_use.ui_text.clone()
1683 } else {
1684 log::error!(
1685 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1686 );
1687 format!("Unknown tool {}", tool_use_id).into()
1688 };
1689
1690 cx.emit(ThreadEvent::InvalidToolInput {
1691 tool_use_id: tool_use_id.clone(),
1692 ui_text,
1693 invalid_input_json: invalid_json,
1694 });
1695
1696 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1697 }
1698
1699 pub fn run_tool(
1700 &mut self,
1701 tool_use_id: LanguageModelToolUseId,
1702 ui_text: impl Into<SharedString>,
1703 input: serde_json::Value,
1704 messages: &[LanguageModelRequestMessage],
1705 tool: Arc<dyn Tool>,
1706 window: Option<AnyWindowHandle>,
1707 cx: &mut Context<Thread>,
1708 ) {
1709 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1710 self.tool_use
1711 .run_pending_tool(tool_use_id, ui_text.into(), task);
1712 }
1713
    /// Starts the given tool and returns a task that records its output on this
    /// thread when it completes. Disabled tools short-circuit to an error
    /// result without running. Any result card is stored immediately so the UI
    /// can render it while the tool is still running.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped by now; ignore that case.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
1762
1763 fn tool_finished(
1764 &mut self,
1765 tool_use_id: LanguageModelToolUseId,
1766 pending_tool_use: Option<PendingToolUse>,
1767 canceled: bool,
1768 window: Option<AnyWindowHandle>,
1769 cx: &mut Context<Self>,
1770 ) {
1771 if self.all_tools_finished() {
1772 let model_registry = LanguageModelRegistry::read_global(cx);
1773 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1774 if !canceled {
1775 self.send_to_model(model, window, cx);
1776 }
1777 self.auto_capture_telemetry(cx);
1778 }
1779 }
1780
1781 cx.emit(ThreadEvent::ToolFinished {
1782 tool_use_id,
1783 pending_tool_use,
1784 });
1785 }
1786
1787 /// Cancels the last pending completion, if there are any pending.
1788 ///
1789 /// Returns whether a completion was canceled.
1790 pub fn cancel_last_completion(
1791 &mut self,
1792 window: Option<AnyWindowHandle>,
1793 cx: &mut Context<Self>,
1794 ) -> bool {
1795 let mut canceled = self.pending_completions.pop().is_some();
1796
1797 for pending_tool_use in self.tool_use.cancel_pending() {
1798 canceled = true;
1799 self.tool_finished(
1800 pending_tool_use.id.clone(),
1801 Some(pending_tool_use),
1802 true,
1803 window,
1804 cx,
1805 );
1806 }
1807
1808 self.finalize_pending_checkpoint(cx);
1809 canceled
1810 }
1811
    /// Thread-level feedback, set by `report_feedback` when no assistant
    /// message exists to attach feedback to.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1815
    /// Feedback previously recorded for the given message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1819
    /// Records thumbs-up/down feedback for a message and reports it (with a
    /// serialized copy of the thread and a project snapshot) via telemetry.
    /// Re-submitting identical feedback for the same message is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the telemetry event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
1877
    /// Reports feedback for the whole thread: attaches it to the most recent
    /// assistant message when one exists, otherwise records thread-level
    /// feedback and emits the telemetry event directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: report at thread granularity.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
1923
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; dirty-buffer paths are
    /// collected afterwards on the main thread. Failures to read app state are
    /// ignored (the snapshot is best-effort telemetry data).
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record paths of all open buffers with unsaved changes.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
1961
1962 fn worktree_snapshot(
1963 worktree: Entity<project::Worktree>,
1964 git_store: Entity<GitStore>,
1965 cx: &App,
1966 ) -> Task<WorktreeSnapshot> {
1967 cx.spawn(async move |cx| {
1968 // Get worktree path and snapshot
1969 let worktree_info = cx.update(|app_cx| {
1970 let worktree = worktree.read(app_cx);
1971 let path = worktree.abs_path().to_string_lossy().to_string();
1972 let snapshot = worktree.snapshot();
1973 (path, snapshot)
1974 });
1975
1976 let Ok((worktree_path, _snapshot)) = worktree_info else {
1977 return WorktreeSnapshot {
1978 worktree_path: String::new(),
1979 git_state: None,
1980 };
1981 };
1982
1983 let git_state = git_store
1984 .update(cx, |git_store, cx| {
1985 git_store
1986 .repositories()
1987 .values()
1988 .find(|repo| {
1989 repo.read(cx)
1990 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1991 .is_some()
1992 })
1993 .cloned()
1994 })
1995 .ok()
1996 .flatten()
1997 .map(|repo| {
1998 repo.update(cx, |repo, _| {
1999 let current_branch =
2000 repo.branch.as_ref().map(|branch| branch.name.to_string());
2001 repo.send_job(None, |state, _| async move {
2002 let RepositoryState::Local { backend, .. } = state else {
2003 return GitState {
2004 remote_url: None,
2005 head_sha: None,
2006 current_branch,
2007 diff: None,
2008 };
2009 };
2010
2011 let remote_url = backend.remote_url("origin");
2012 let head_sha = backend.head_sha().await;
2013 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2014
2015 GitState {
2016 remote_url,
2017 head_sha,
2018 current_branch,
2019 diff,
2020 }
2021 })
2022 })
2023 });
2024
2025 let git_state = match git_state {
2026 Some(git_state) => match git_state.ok() {
2027 Some(git_state) => git_state.await.ok(),
2028 None => None,
2029 },
2030 None => None,
2031 };
2032
2033 WorktreeSnapshot {
2034 worktree_path,
2035 git_state,
2036 }
2037 })
2038 }
2039
2040 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2041 let mut markdown = Vec::new();
2042
2043 if let Some(summary) = self.summary() {
2044 writeln!(markdown, "# {summary}\n")?;
2045 };
2046
2047 for message in self.messages() {
2048 writeln!(
2049 markdown,
2050 "## {role}\n",
2051 role = match message.role {
2052 Role::User => "User",
2053 Role::Assistant => "Assistant",
2054 Role::System => "System",
2055 }
2056 )?;
2057
2058 if !message.loaded_context.text.is_empty() {
2059 writeln!(markdown, "{}", message.loaded_context.text)?;
2060 }
2061
2062 if !message.loaded_context.images.is_empty() {
2063 writeln!(
2064 markdown,
2065 "\n{} images attached as context.\n",
2066 message.loaded_context.images.len()
2067 )?;
2068 }
2069
2070 for segment in &message.segments {
2071 match segment {
2072 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2073 MessageSegment::Thinking { text, .. } => {
2074 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2075 }
2076 MessageSegment::RedactedThinking(_) => {}
2077 }
2078 }
2079
2080 for tool_use in self.tool_uses_for_message(message.id, cx) {
2081 writeln!(
2082 markdown,
2083 "**Use Tool: {} ({})**",
2084 tool_use.name, tool_use.id
2085 )?;
2086 writeln!(markdown, "```json")?;
2087 writeln!(
2088 markdown,
2089 "{}",
2090 serde_json::to_string_pretty(&tool_use.input)?
2091 )?;
2092 writeln!(markdown, "```")?;
2093 }
2094
2095 for tool_result in self.tool_results_for_message(message.id) {
2096 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2097 if tool_result.is_error {
2098 write!(markdown, " (Error)")?;
2099 }
2100
2101 writeln!(markdown, "**\n")?;
2102 writeln!(markdown, "{}", tool_result.content)?;
2103 }
2104 }
2105
2106 Ok(String::from_utf8_lossy(&markdown).to_string())
2107 }
2108
    /// Accepts (keeps) the agent's edits within the given range of `buffer`,
    /// delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2119
    /// Accepts (keeps) every outstanding agent edit tracked by the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2124
    /// Rejects (reverts) the agent's edits within the given ranges of `buffer`,
    /// delegating to the action log; the returned task resolves when done.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2135
    /// The action log tracking buffer edits made on behalf of this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2139
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2143
2144 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2145 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2146 return;
2147 }
2148
2149 let now = Instant::now();
2150 if let Some(last) = self.last_auto_capture_at {
2151 if now.duration_since(last).as_secs() < 10 {
2152 return;
2153 }
2154 }
2155
2156 self.last_auto_capture_at = Some(now);
2157
2158 let thread_id = self.id().clone();
2159 let github_login = self
2160 .project
2161 .read(cx)
2162 .user_store()
2163 .read(cx)
2164 .current_user()
2165 .map(|user| user.github_login.clone());
2166 let client = self.project.read(cx).client().clone();
2167 let serialize_task = self.serialize(cx);
2168
2169 cx.background_executor()
2170 .spawn(async move {
2171 if let Ok(serialized_thread) = serialize_task.await {
2172 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2173 telemetry::event!(
2174 "Agent Thread Auto-Captured",
2175 thread_id = thread_id.to_string(),
2176 thread_data = thread_data,
2177 auto_capture_reason = "tracked_user",
2178 github_login = github_login
2179 );
2180
2181 client.telemetry().flush_events().await;
2182 }
2183 }
2184 })
2185 .detach();
2186 }
2187
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2191
2192 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2193 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2194 return TotalTokenUsage::default();
2195 };
2196
2197 let max = model.model.max_token_count();
2198
2199 let index = self
2200 .messages
2201 .iter()
2202 .position(|msg| msg.id == message_id)
2203 .unwrap_or(0);
2204
2205 if index == 0 {
2206 return TotalTokenUsage { total: 0, max };
2207 }
2208
2209 let token_usage = &self
2210 .request_token_usage
2211 .get(index - 1)
2212 .cloned()
2213 .unwrap_or_default();
2214
2215 TotalTokenUsage {
2216 total: token_usage.total_tokens() as usize,
2217 max,
2218 }
2219 }
2220
2221 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2222 let model_registry = LanguageModelRegistry::read_global(cx);
2223 let Some(model) = model_registry.default_model() else {
2224 return TotalTokenUsage::default();
2225 };
2226
2227 let max = model.model.max_token_count();
2228
2229 if let Some(exceeded_error) = &self.exceeded_window_error {
2230 if model.model.id() == exceeded_error.model_id {
2231 return TotalTokenUsage {
2232 total: exceeded_error.token_count,
2233 max,
2234 };
2235 }
2236 }
2237
2238 let total = self
2239 .token_usage_at_last_message()
2240 .unwrap_or_default()
2241 .total_tokens() as usize;
2242
2243 TotalTokenUsage { total, max }
2244 }
2245
    /// Token usage recorded for the most recent message, falling back to the
    /// last recorded entry when `request_token_usage` is shorter than the
    /// message list (e.g. usage not yet recorded for the newest message).
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2252
    /// Records `token_usage` for the latest message, growing the usage vector
    /// to match the message count (back-filling any gap with the previously
    /// recorded usage) before overwriting the final slot.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Back-fill missing entries with the last known usage rather than zero.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2262
2263 pub fn deny_tool_use(
2264 &mut self,
2265 tool_use_id: LanguageModelToolUseId,
2266 tool_name: Arc<str>,
2267 window: Option<AnyWindowHandle>,
2268 cx: &mut Context<Self>,
2269 ) {
2270 let err = Err(anyhow::anyhow!(
2271 "Permission to run tool action denied by user"
2272 ));
2273
2274 self.tool_use
2275 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2276 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2277 }
2278}
2279
/// User-facing errors surfaced from a thread via `ThreadEvent::ShowError`.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider rejected the request for lack of payment.
    #[error("Payment required")]
    PaymentRequired,
    /// The account's monthly spend cap has been hit.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model-request quota has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, shown with a header and the full message chain.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2294
/// Events emitted by a `Thread` as completions stream in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// New request-usage information was received.
    UsageUpdated(RequestUsage),
    /// A completion finished streaming.
    StreamedCompletion,
    /// A chunk of text was received from the model.
    ReceivedTextChunk,
    /// Streamed assistant text to append to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text to append to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that was not valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
}
2331
// Lets `Thread` entities emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2333
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier used to tell whether this completion is still current.
    id: usize,
    // Held to keep the streaming task running; presumably dropping it cancels
    // the work, as gpui tasks are canceled on drop — confirm against `Task`.
    _task: Task<()>,
}
2338
2339#[cfg(test)]
2340mod tests {
2341 use super::*;
2342 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2343 use assistant_settings::AssistantSettings;
2344 use assistant_tool::ToolRegistry;
2345 use context_server::ContextServerSettings;
2346 use editor::EditorSettings;
2347 use gpui::TestAppContext;
2348 use language_model::fake_provider::FakeLanguageModel;
2349 use project::{FakeFs, Project};
2350 use prompt_store::PromptBuilder;
2351 use serde_json::json;
2352 use settings::{Settings, SettingsStore};
2353 use std::sync::Arc;
2354 use theme::ThemeSettings;
2355 use util::path;
2356 use workspace::Workspace;
2357
    /// Verifies that a user message with an attached file produces the exact
    /// `<context>` block in both the stored message and the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // Load the single attached context entry into text form.
        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Two messages: the system message plus the user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2427
    /// Verifies that each user message only includes context that has not
    /// already been attached to an earlier message in the same thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }
2533
    /// Verifies that messages inserted with an empty `ContextLoadResult` carry
    /// no context text and round-trip verbatim into the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Two messages: the system message plus the user message.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2609
    /// Verifies that editing a buffer after it was attached as context causes a
    /// trailing "files changed since last read" user message in the request.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at character offset 1 (edit ranges are buffer offsets,
            // not line numbers), making the buffer stale relative to context.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2695
    /// Registers the global settings and subsystems the thread tests rely on.
    ///
    /// NOTE(review): later registrations may depend on earlier ones (the
    /// settings store must be installed first), so the order is preserved.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
2712
2713 // Helper to create a test project with test files
2714 async fn create_test_project(
2715 cx: &mut TestAppContext,
2716 files: serde_json::Value,
2717 ) -> Entity<Project> {
2718 let fs = FakeFs::new(cx.executor());
2719 fs.insert_tree(path!("/test"), files).await;
2720 Project::test(fs, [path!("/test").as_ref()], cx).await
2721 }
2722
2723 async fn setup_test_environment(
2724 cx: &mut TestAppContext,
2725 project: Entity<Project>,
2726 ) -> (
2727 Entity<Workspace>,
2728 Entity<ThreadStore>,
2729 Entity<Thread>,
2730 Entity<ContextStore>,
2731 Arc<dyn LanguageModel>,
2732 ) {
2733 let (workspace, cx) =
2734 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2735
2736 let thread_store = cx
2737 .update(|_, cx| {
2738 ThreadStore::load(
2739 project.clone(),
2740 cx.new(|_| ToolWorkingSet::default()),
2741 None,
2742 Arc::new(PromptBuilder::new(None).unwrap()),
2743 cx,
2744 )
2745 })
2746 .await
2747 .unwrap();
2748
2749 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2750 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2751
2752 let model = FakeLanguageModel::default();
2753 let model: Arc<dyn LanguageModel> = Arc::new(model);
2754
2755 (workspace, thread_store, thread, context_store, model)
2756 }
2757
2758 async fn add_file_to_context(
2759 project: &Entity<Project>,
2760 context_store: &Entity<ContextStore>,
2761 path: &str,
2762 cx: &mut TestAppContext,
2763 ) -> Result<Entity<language::Buffer>> {
2764 let buffer_path = project
2765 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2766 .unwrap();
2767
2768 let buffer = project
2769 .update(cx, |project, cx| {
2770 project.open_buffer(buffer_path.clone(), cx)
2771 })
2772 .await
2773 .unwrap();
2774
2775 context_store.update(cx, |context_store, cx| {
2776 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
2777 });
2778
2779 Ok(buffer)
2780 }
2781}