1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage, WrappedTextContent,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Range within the message text that the crease covers.
    pub range: Range<usize>,
    // Display metadata (icon, label) used to re-render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered content pieces: text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    // Context (files, etc.) attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    // Creases to restore when re-opening this message in an editor.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// One piece of a message's content: plain text, visible chain-of-thought,
/// or provider-redacted (opaque) thinking data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Whether this segment carries user-visible content.
    ///
    /// Text and thinking segments are displayable only when they actually
    /// contain text; redacted thinking is opaque and never displayed.
    /// (The emptiness checks were previously inverted, reporting empty
    /// segments as displayable and non-empty ones as hidden.)
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Point-in-time capture of the project's worktrees and unsaved buffers,
/// recorded alongside a thread for later inspection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
202
/// Snapshot of a single worktree: its path plus optional git state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}
208
/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    // Diff text at snapshot time, if one was computed.
    pub diff: Option<String>,
}
216
/// Associates a git checkpoint with the thread message it was taken at,
/// so the project can later be restored to the state before that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
222
/// User feedback (thumbs up / thumbs down) on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
229pub enum LastRestoreCheckpoint {
230 Pending {
231 message_id: MessageId,
232 },
233 Error {
234 message_id: MessageId,
235 error: String,
236 },
237}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
/// Lifecycle of the long-form thread summary: not started, in flight for a
/// given message, or generated up to a given message.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token total for a thread together with the model's context-window maximum.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage relative to the context window.
    ///
    /// In debug builds the warning threshold can be overridden via the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable; a missing or
    /// unparsable value now falls back to the default instead of panicking
    /// (the previous `.parse().unwrap()` would abort on a malformed value).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total; `max` is unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Severity bucket produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// Where a pending completion currently sits in the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Short title of the thread; see `ThreadSummary` for the states.
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // Ordered by ascending `MessageId`; `message()` binary-searches on this invariant.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per message, for checkpoint restore.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken at the latest user message, finalized when generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Project state captured when the thread was created (or loaded from disk).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recently streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Remaining agentic turns; `send_to_model` is a no-op once this hits zero.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
364#[derive(Clone, Debug, PartialEq, Eq)]
365pub enum ThreadSummary {
366 Pending,
367 Generating,
368 Ready(SharedString),
369 Error,
370}
371
372impl ThreadSummary {
373 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
374
375 pub fn or_default(&self) -> SharedString {
376 self.unwrap_or(Self::DEFAULT)
377 }
378
379 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
380 self.ready().unwrap_or_else(|| message.into())
381 }
382
383 pub fn ready(&self) -> Option<SharedString> {
384 match self {
385 ThreadSummary::Ready(summary) => Some(summary.clone()),
386 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
387 }
388 }
389}
390
/// Records that a request exceeded the model's context window, so the UI can
/// explain why the last message failed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
398
399impl Thread {
    /// Creates an empty thread with a fresh id, the registry's default model,
    /// and the user's preferred completion mode. A snapshot of the project's
    /// initial state is captured eagerly on a background task.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            // Effectively unlimited turns until a caller restricts them.
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
453
    /// Reconstructs a thread from its serialized form.
    ///
    /// The next message id continues after the last serialized message; the
    /// serialized model selection is re-resolved against the registry (falling
    /// back to the default model); crease contexts are left as `None` since
    /// they cannot be revived from disk.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Contexts are not serialized and cannot be restored.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
573
    /// Installs a callback invoked with each completion request and the raw
    /// events it produced (used by tests/diagnostics).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
581
    /// Stable identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Last time the thread's content changed.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared project context used to build the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
605
    /// Returns the configured model, initializing it from the registry's
    /// default when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently selected model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the selected model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Current state of the thread's short title.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Sets the summary text, ignoring the call while a summary is pending or
    /// generating. An empty summary falls back to the default title; an event
    /// is emitted only when the text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// The thread's completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the thread's completion mode.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
652
653 pub fn message(&self, id: MessageId) -> Option<&Message> {
654 let index = self
655 .messages
656 .binary_search_by(|message| message.id.cmp(&id))
657 .ok()?;
658
659 self.messages.get(index)
660 }
661
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is streaming or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // NOTE(review): staleness is derived purely from `last_received_chunk_at`;
        // the "None while not generating" contract above holds only if that field
        // is cleared when generation ends — confirm against the completion path.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue position of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The thread's tool working set.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by its id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool use is still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded at `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
714
    /// Restores the project to `checkpoint`'s git state, then truncates the
    /// thread back to the checkpoint's message on success. Progress/failure is
    /// surfaced through `last_restore_checkpoint` and `CheckpointChanged` events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any in-flight pending checkpoint is invalidated by the restore.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
750
    /// Once generation has finished, compares the pending checkpoint's git
    /// state with the current state and keeps the checkpoint only if the two
    /// differ (or if the comparison itself fails, erring on the side of keeping it).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only record the checkpoint when the model actually changed something.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
789
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Status of the most recent restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes `message_id` and every later message, dropping any checkpoints
    /// attached to the removed messages. No-op when the id is not found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Contexts attached to the message with the given id (empty when not found).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
822
823 pub fn is_turn_end(&self, ix: usize) -> bool {
824 if self.messages.is_empty() {
825 return false;
826 }
827
828 if !self.is_generating() && ix == self.messages.len() - 1 {
829 return true;
830 }
831
832 let Some(message) = self.messages.get(ix) else {
833 return false;
834 };
835
836 if message.role != Role::Assistant {
837 return false;
838 }
839
840 self.messages
841 .get(ix + 1)
842 .and_then(|message| {
843 self.message(message.id)
844 .map(|next_message| next_message.role == Role::User)
845 })
846 .unwrap_or(false)
847 }
848
    /// Usage reported by the provider for the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the provider reported that the tool-use limit was hit.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses attached to the given assistant message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced in response to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Textual output of a tool use; `None` for image results or unknown ids.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text)
            | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => {
                Some(text)
            }
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The UI card associated with a tool result, if one was registered.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
898
899 /// Return tools that are both enabled and supported by the model
900 pub fn available_tools(
901 &self,
902 cx: &App,
903 model: Arc<dyn LanguageModel>,
904 ) -> Vec<LanguageModelRequestTool> {
905 if model.supports_tools() {
906 self.tools()
907 .read(cx)
908 .enabled_tools(cx)
909 .into_iter()
910 .filter_map(|tool| {
911 // Skip tools that cannot be supported
912 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
913 Some(LanguageModelRequestTool {
914 name: tool.name(),
915 description: tool.description(),
916 input_schema,
917 })
918 })
919 .collect()
920 } else {
921 Vec::default()
922 }
923 }
924
    /// Appends a user message, marking any context buffers as read in the
    /// action log and stashing `git_checkpoint` as the pending checkpoint to be
    /// finalized when generation completes.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
960
    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message with the next id, touches the updated-at timestamp,
    /// and emits `MessageAdded`. Keeps `messages` ordered by ascending id.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
995
    /// Replaces a message's role and segments in place; context and checkpoint
    /// are only updated when provided. Returns `false` when the id is unknown.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes the message with the given id. Returns `false` when not found.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
1036
1037 /// Returns the representation of this [`Thread`] in a textual form.
1038 ///
1039 /// This is the representation we use when attaching a thread as context to another thread.
1040 pub fn text(&self) -> String {
1041 let mut text = String::new();
1042
1043 for message in &self.messages {
1044 text.push_str(match message.role {
1045 language_model::Role::User => "User:",
1046 language_model::Role::Assistant => "Agent:",
1047 language_model::Role::System => "System:",
1048 });
1049 text.push('\n');
1050
1051 for segment in &message.segments {
1052 match segment {
1053 MessageSegment::Text(content) => text.push_str(content),
1054 MessageSegment::Thinking { text: content, .. } => {
1055 text.push_str(&format!("<think>{}</think>", content))
1056 }
1057 MessageSegment::RedactedThinking(_) => {}
1058 }
1059 }
1060 text.push('\n');
1061 }
1062
1063 text
1064 }
1065
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot before reading the thread state,
    /// converting messages, tool uses/results, and creases into their
    /// serialized counterparts.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1148
    /// Returns how many agentic turns may still be sent to the model before
    /// this thread stops prompting automatically.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1152
    /// Sets the number of turns this thread may still send to the model.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1156
1157 pub fn send_to_model(
1158 &mut self,
1159 model: Arc<dyn LanguageModel>,
1160 window: Option<AnyWindowHandle>,
1161 cx: &mut Context<Self>,
1162 ) {
1163 if self.remaining_turns == 0 {
1164 return;
1165 }
1166
1167 self.remaining_turns -= 1;
1168
1169 let request = self.to_completion_request(model.clone(), cx);
1170
1171 self.stream_completion(request, model, window, cx);
1172 }
1173
1174 pub fn used_tools_since_last_user_message(&self) -> bool {
1175 for message in self.messages.iter().rev() {
1176 if self.tool_use.message_has_tool_results(message.id) {
1177 return true;
1178 } else if message.role == Role::User {
1179 return false;
1180 }
1181 }
1182
1183 false
1184 }
1185
    /// Builds the full [`LanguageModelRequest`] for this thread: system
    /// prompt, per-message content (text/thinking segments plus completed
    /// tool uses and their results), stale-file notices, available tools,
    /// and the completion mode.
    ///
    /// Emits a `ShowError` thread event (and omits the system message) if
    /// the system prompt cannot be generated.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt depends on project context that is populated
        // asynchronously; if it isn't ready yet we surface an error rather
        // than sending a request with no system message silently.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible for prompt caching;
        // only messages whose tool uses are all resolved qualify.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Empty text segments are dropped from the request.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool uses go on the assistant message; their results are
            // collected into a synthetic follow-up user message, per the
            // completion API's tool-calling shape.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use means this message isn't stable yet,
                    // so it must not be the cache anchor.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1343
1344 fn to_summarize_request(
1345 &self,
1346 model: &Arc<dyn LanguageModel>,
1347 added_user_message: String,
1348 cx: &App,
1349 ) -> LanguageModelRequest {
1350 let mut request = LanguageModelRequest {
1351 thread_id: None,
1352 prompt_id: None,
1353 mode: None,
1354 messages: vec![],
1355 tools: Vec::new(),
1356 tool_choice: None,
1357 stop: Vec::new(),
1358 temperature: AssistantSettings::temperature_for_model(model, cx),
1359 };
1360
1361 for message in &self.messages {
1362 let mut request_message = LanguageModelRequestMessage {
1363 role: message.role,
1364 content: Vec::new(),
1365 cache: false,
1366 };
1367
1368 for segment in &message.segments {
1369 match segment {
1370 MessageSegment::Text(text) => request_message
1371 .content
1372 .push(MessageContent::Text(text.clone())),
1373 MessageSegment::Thinking { .. } => {}
1374 MessageSegment::RedactedThinking(_) => {}
1375 }
1376 }
1377
1378 if request_message.content.is_empty() {
1379 continue;
1380 }
1381
1382 request.messages.push(request_message);
1383 }
1384
1385 request.messages.push(LanguageModelRequestMessage {
1386 role: Role::User,
1387 content: vec![MessageContent::Text(added_user_message)],
1388 cache: false,
1389 });
1390
1391 request
1392 }
1393
1394 fn attached_tracked_files_state(
1395 &self,
1396 messages: &mut Vec<LanguageModelRequestMessage>,
1397 cx: &App,
1398 ) {
1399 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1400
1401 let mut stale_message = String::new();
1402
1403 let action_log = self.action_log.read(cx);
1404
1405 for stale_file in action_log.stale_buffers(cx) {
1406 let Some(file) = stale_file.read(cx).file() else {
1407 continue;
1408 };
1409
1410 if stale_message.is_empty() {
1411 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1412 }
1413
1414 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1415 }
1416
1417 let mut content = Vec::with_capacity(2);
1418
1419 if !stale_message.is_empty() {
1420 content.push(stale_message.into());
1421 }
1422
1423 if !content.is_empty() {
1424 let context_message = LanguageModelRequestMessage {
1425 role: Role::User,
1426 content,
1427 cache: false,
1428 };
1429
1430 messages.push(context_message);
1431 }
1432 }
1433
1434 pub fn stream_completion(
1435 &mut self,
1436 request: LanguageModelRequest,
1437 model: Arc<dyn LanguageModel>,
1438 window: Option<AnyWindowHandle>,
1439 cx: &mut Context<Self>,
1440 ) {
1441 self.tool_use_limit_reached = false;
1442
1443 let pending_completion_id = post_inc(&mut self.completion_count);
1444 let mut request_callback_parameters = if self.request_callback.is_some() {
1445 Some((request.clone(), Vec::new()))
1446 } else {
1447 None
1448 };
1449 let prompt_id = self.last_prompt_id.clone();
1450 let tool_use_metadata = ToolUseMetadata {
1451 model: model.clone(),
1452 thread_id: self.id.clone(),
1453 prompt_id: prompt_id.clone(),
1454 };
1455
1456 self.last_received_chunk_at = Some(Instant::now());
1457
1458 let task = cx.spawn(async move |thread, cx| {
1459 let stream_completion_future = model.stream_completion(request, &cx);
1460 let initial_token_usage =
1461 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1462 let stream_completion = async {
1463 let mut events = stream_completion_future.await?;
1464
1465 let mut stop_reason = StopReason::EndTurn;
1466 let mut current_token_usage = TokenUsage::default();
1467
1468 thread
1469 .update(cx, |_thread, cx| {
1470 cx.emit(ThreadEvent::NewRequest);
1471 })
1472 .ok();
1473
1474 let mut request_assistant_message_id = None;
1475
1476 while let Some(event) = events.next().await {
1477 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1478 response_events
1479 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1480 }
1481
1482 thread.update(cx, |thread, cx| {
1483 let event = match event {
1484 Ok(event) => event,
1485 Err(LanguageModelCompletionError::BadInputJson {
1486 id,
1487 tool_name,
1488 raw_input: invalid_input_json,
1489 json_parse_error,
1490 }) => {
1491 thread.receive_invalid_tool_json(
1492 id,
1493 tool_name,
1494 invalid_input_json,
1495 json_parse_error,
1496 window,
1497 cx,
1498 );
1499 return Ok(());
1500 }
1501 Err(LanguageModelCompletionError::Other(error)) => {
1502 return Err(error);
1503 }
1504 };
1505
1506 match event {
1507 LanguageModelCompletionEvent::StartMessage { .. } => {
1508 request_assistant_message_id =
1509 Some(thread.insert_assistant_message(
1510 vec![MessageSegment::Text(String::new())],
1511 cx,
1512 ));
1513 }
1514 LanguageModelCompletionEvent::Stop(reason) => {
1515 stop_reason = reason;
1516 }
1517 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1518 thread.update_token_usage_at_last_message(token_usage);
1519 thread.cumulative_token_usage = thread.cumulative_token_usage
1520 + token_usage
1521 - current_token_usage;
1522 current_token_usage = token_usage;
1523 }
1524 LanguageModelCompletionEvent::Text(chunk) => {
1525 thread.received_chunk();
1526
1527 cx.emit(ThreadEvent::ReceivedTextChunk);
1528 if let Some(last_message) = thread.messages.last_mut() {
1529 if last_message.role == Role::Assistant
1530 && !thread.tool_use.has_tool_results(last_message.id)
1531 {
1532 last_message.push_text(&chunk);
1533 cx.emit(ThreadEvent::StreamedAssistantText(
1534 last_message.id,
1535 chunk,
1536 ));
1537 } else {
1538 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1539 // of a new Assistant response.
1540 //
1541 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1542 // will result in duplicating the text of the chunk in the rendered Markdown.
1543 request_assistant_message_id =
1544 Some(thread.insert_assistant_message(
1545 vec![MessageSegment::Text(chunk.to_string())],
1546 cx,
1547 ));
1548 };
1549 }
1550 }
1551 LanguageModelCompletionEvent::Thinking {
1552 text: chunk,
1553 signature,
1554 } => {
1555 thread.received_chunk();
1556
1557 if let Some(last_message) = thread.messages.last_mut() {
1558 if last_message.role == Role::Assistant
1559 && !thread.tool_use.has_tool_results(last_message.id)
1560 {
1561 last_message.push_thinking(&chunk, signature);
1562 cx.emit(ThreadEvent::StreamedAssistantThinking(
1563 last_message.id,
1564 chunk,
1565 ));
1566 } else {
1567 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1568 // of a new Assistant response.
1569 //
1570 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1571 // will result in duplicating the text of the chunk in the rendered Markdown.
1572 request_assistant_message_id =
1573 Some(thread.insert_assistant_message(
1574 vec![MessageSegment::Thinking {
1575 text: chunk.to_string(),
1576 signature,
1577 }],
1578 cx,
1579 ));
1580 };
1581 }
1582 }
1583 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1584 let last_assistant_message_id = request_assistant_message_id
1585 .unwrap_or_else(|| {
1586 let new_assistant_message_id =
1587 thread.insert_assistant_message(vec![], cx);
1588 request_assistant_message_id =
1589 Some(new_assistant_message_id);
1590 new_assistant_message_id
1591 });
1592
1593 let tool_use_id = tool_use.id.clone();
1594 let streamed_input = if tool_use.is_input_complete {
1595 None
1596 } else {
1597 Some((&tool_use.input).clone())
1598 };
1599
1600 let ui_text = thread.tool_use.request_tool_use(
1601 last_assistant_message_id,
1602 tool_use,
1603 tool_use_metadata.clone(),
1604 cx,
1605 );
1606
1607 if let Some(input) = streamed_input {
1608 cx.emit(ThreadEvent::StreamedToolUse {
1609 tool_use_id,
1610 ui_text,
1611 input,
1612 });
1613 }
1614 }
1615 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1616 if let Some(completion) = thread
1617 .pending_completions
1618 .iter_mut()
1619 .find(|completion| completion.id == pending_completion_id)
1620 {
1621 match status_update {
1622 CompletionRequestStatus::Queued {
1623 position,
1624 } => {
1625 completion.queue_state = QueueState::Queued { position };
1626 }
1627 CompletionRequestStatus::Started => {
1628 completion.queue_state = QueueState::Started;
1629 }
1630 CompletionRequestStatus::Failed {
1631 code, message, request_id
1632 } => {
1633 anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
1634 }
1635 CompletionRequestStatus::UsageUpdated {
1636 amount, limit
1637 } => {
1638 let usage = RequestUsage { limit, amount: amount as i32 };
1639
1640 thread.last_usage = Some(usage);
1641 }
1642 CompletionRequestStatus::ToolUseLimitReached => {
1643 thread.tool_use_limit_reached = true;
1644 }
1645 }
1646 }
1647 }
1648 }
1649
1650 thread.touch_updated_at();
1651 cx.emit(ThreadEvent::StreamedCompletion);
1652 cx.notify();
1653
1654 thread.auto_capture_telemetry(cx);
1655 Ok(())
1656 })??;
1657
1658 smol::future::yield_now().await;
1659 }
1660
1661 thread.update(cx, |thread, cx| {
1662 thread.last_received_chunk_at = None;
1663 thread
1664 .pending_completions
1665 .retain(|completion| completion.id != pending_completion_id);
1666
1667 // If there is a response without tool use, summarize the message. Otherwise,
1668 // allow two tool uses before summarizing.
1669 if matches!(thread.summary, ThreadSummary::Pending)
1670 && thread.messages.len() >= 2
1671 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1672 {
1673 thread.summarize(cx);
1674 }
1675 })?;
1676
1677 anyhow::Ok(stop_reason)
1678 };
1679
1680 let result = stream_completion.await;
1681
1682 thread
1683 .update(cx, |thread, cx| {
1684 thread.finalize_pending_checkpoint(cx);
1685 match result.as_ref() {
1686 Ok(stop_reason) => match stop_reason {
1687 StopReason::ToolUse => {
1688 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1689 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1690 }
1691 StopReason::EndTurn | StopReason::MaxTokens => {
1692 thread.project.update(cx, |project, cx| {
1693 project.set_agent_location(None, cx);
1694 });
1695 }
1696 StopReason::Refusal => {
1697 thread.project.update(cx, |project, cx| {
1698 project.set_agent_location(None, cx);
1699 });
1700
1701 cx.emit (ThreadEvent::ShowError(ThreadError::Message {
1702 header: "Language model refusal".into(),
1703 message: "Model refused to generate content for safety reasons.".into(),
1704 }));
1705 }
1706 },
1707 Err(error) => {
1708 thread.project.update(cx, |project, cx| {
1709 project.set_agent_location(None, cx);
1710 });
1711
1712 if error.is::<PaymentRequiredError>() {
1713 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1714 } else if let Some(error) =
1715 error.downcast_ref::<ModelRequestLimitReachedError>()
1716 {
1717 cx.emit(ThreadEvent::ShowError(
1718 ThreadError::ModelRequestLimitReached { plan: error.plan },
1719 ));
1720 } else if let Some(known_error) =
1721 error.downcast_ref::<LanguageModelKnownError>()
1722 {
1723 match known_error {
1724 LanguageModelKnownError::ContextWindowLimitExceeded {
1725 tokens,
1726 } => {
1727 thread.exceeded_window_error = Some(ExceededWindowError {
1728 model_id: model.id(),
1729 token_count: *tokens,
1730 });
1731 cx.notify();
1732 }
1733 }
1734 } else {
1735 let error_message = error
1736 .chain()
1737 .map(|err| err.to_string())
1738 .collect::<Vec<_>>()
1739 .join("\n");
1740 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1741 header: "Error interacting with language model".into(),
1742 message: SharedString::from(error_message.clone()),
1743 }));
1744 }
1745
1746 thread.cancel_last_completion(window, cx);
1747 }
1748 }
1749 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1750
1751 if let Some((request_callback, (request, response_events))) = thread
1752 .request_callback
1753 .as_mut()
1754 .zip(request_callback_parameters.as_ref())
1755 {
1756 request_callback(request, response_events);
1757 }
1758
1759 thread.auto_capture_telemetry(cx);
1760
1761 if let Ok(initial_usage) = initial_token_usage {
1762 let usage = thread.cumulative_token_usage - initial_usage;
1763
1764 telemetry::event!(
1765 "Assistant Thread Completion",
1766 thread_id = thread.id().to_string(),
1767 prompt_id = prompt_id,
1768 model = model.telemetry_id(),
1769 model_provider = model.provider_id().to_string(),
1770 input_tokens = usage.input_tokens,
1771 output_tokens = usage.output_tokens,
1772 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1773 cache_read_input_tokens = usage.cache_read_input_tokens,
1774 );
1775 }
1776 })
1777 .ok();
1778 });
1779
1780 self.pending_completions.push(PendingCompletion {
1781 id: pending_completion_id,
1782 queue_state: QueueState::Sending,
1783 _task: task,
1784 });
1785 }
1786
1787 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1788 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1789 println!("No thread summary model");
1790 return;
1791 };
1792
1793 if !model.provider.is_authenticated(cx) {
1794 return;
1795 }
1796
1797 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1798 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1799 If the conversation is about a specific subject, include it in the title. \
1800 Be descriptive. DO NOT speak in the first person.";
1801
1802 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1803
1804 self.summary = ThreadSummary::Generating;
1805
1806 self.pending_summary = cx.spawn(async move |this, cx| {
1807 let result = async {
1808 let mut messages = model.model.stream_completion(request, &cx).await?;
1809
1810 let mut new_summary = String::new();
1811 while let Some(event) = messages.next().await {
1812 let Ok(event) = event else {
1813 continue;
1814 };
1815 let text = match event {
1816 LanguageModelCompletionEvent::Text(text) => text,
1817 LanguageModelCompletionEvent::StatusUpdate(
1818 CompletionRequestStatus::UsageUpdated { amount, limit },
1819 ) => {
1820 this.update(cx, |thread, _cx| {
1821 thread.last_usage = Some(RequestUsage {
1822 limit,
1823 amount: amount as i32,
1824 });
1825 })?;
1826 continue;
1827 }
1828 _ => continue,
1829 };
1830
1831 let mut lines = text.lines();
1832 new_summary.extend(lines.next());
1833
1834 // Stop if the LLM generated multiple lines.
1835 if lines.next().is_some() {
1836 break;
1837 }
1838 }
1839
1840 anyhow::Ok(new_summary)
1841 }
1842 .await;
1843
1844 this.update(cx, |this, cx| {
1845 match result {
1846 Ok(new_summary) => {
1847 if new_summary.is_empty() {
1848 this.summary = ThreadSummary::Error;
1849 } else {
1850 this.summary = ThreadSummary::Ready(new_summary.into());
1851 }
1852 }
1853 Err(err) => {
1854 this.summary = ThreadSummary::Error;
1855 log::error!("Failed to generate thread summary: {}", err);
1856 }
1857 }
1858 cx.emit(ThreadEvent::SummaryGenerated);
1859 })
1860 .log_err()?;
1861
1862 Some(())
1863 });
1864 }
1865
    /// Starts generating a detailed Markdown summary of the thread, unless
    /// one has already been generated (or is being generated) for the latest
    /// message. Progress is published through the `detailed_summary` watch
    /// channel, and the finished thread is saved via `thread_store` so the
    /// summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the stream can't even be opened, roll the state back to
            // `NotGenerated` so a later attempt can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1955
1956 pub async fn wait_for_detailed_summary_or_text(
1957 this: &Entity<Self>,
1958 cx: &mut AsyncApp,
1959 ) -> Option<SharedString> {
1960 let mut detailed_summary_rx = this
1961 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1962 .ok()?;
1963 loop {
1964 match detailed_summary_rx.recv().await? {
1965 DetailedSummaryState::Generating { .. } => {}
1966 DetailedSummaryState::NotGenerated => {
1967 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1968 }
1969 DetailedSummaryState::Generated { text, .. } => return Some(text),
1970 }
1971 }
1972 }
1973
1974 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1975 self.detailed_summary_rx
1976 .borrow()
1977 .text()
1978 .unwrap_or_else(|| self.text().into())
1979 }
1980
1981 pub fn is_generating_detailed_summary(&self) -> bool {
1982 matches!(
1983 &*self.detailed_summary_rx.borrow(),
1984 DetailedSummaryState::Generating { .. }
1985 )
1986 }
1987
    /// Dispatches every idle pending tool use.
    ///
    /// Tools that require confirmation (unless the user has globally allowed
    /// tool actions) are queued for confirmation instead of running; tool
    /// names the model hallucinated are answered with an error result listing
    /// the real tools. Returns the pending tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // The same request snapshot is shared by every tool run below.
        let request = Arc::new(self.to_completion_request(model.clone(), cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool we don't have.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2041
2042 pub fn handle_hallucinated_tool_use(
2043 &mut self,
2044 tool_use_id: LanguageModelToolUseId,
2045 hallucinated_tool_name: Arc<str>,
2046 window: Option<AnyWindowHandle>,
2047 cx: &mut Context<Thread>,
2048 ) {
2049 let available_tools = self.tools.read(cx).enabled_tools(cx);
2050
2051 let tool_list = available_tools
2052 .iter()
2053 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2054 .collect::<Vec<_>>()
2055 .join("\n");
2056
2057 let error_message = format!(
2058 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2059 hallucinated_tool_name, tool_list
2060 );
2061
2062 let pending_tool_use = self.tool_use.insert_tool_output(
2063 tool_use_id.clone(),
2064 hallucinated_tool_name,
2065 Err(anyhow!("Missing tool call: {error_message}")),
2066 self.configured_model.as_ref(),
2067 );
2068
2069 cx.emit(ThreadEvent::MissingToolUse {
2070 tool_use_id: tool_use_id.clone(),
2071 ui_text: error_message.into(),
2072 });
2073
2074 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2075 }
2076
    /// Handles a tool call whose input JSON failed to parse: records an
    /// error tool result for the tool use and notifies listeners with the
    /// invalid payload so it can be surfaced in the UI.
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
        );
        // Prefer the label of the tracked pending tool use; fall back to a
        // generic label if (unexpectedly) none was tracked.
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
2111
2112 pub fn run_tool(
2113 &mut self,
2114 tool_use_id: LanguageModelToolUseId,
2115 ui_text: impl Into<SharedString>,
2116 input: serde_json::Value,
2117 request: Arc<LanguageModelRequest>,
2118 tool: Arc<dyn Tool>,
2119 model: Arc<dyn LanguageModel>,
2120 window: Option<AnyWindowHandle>,
2121 cx: &mut Context<Thread>,
2122 ) {
2123 let task =
2124 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2125 self.tool_use
2126 .run_pending_tool(tool_use_id, ui_text.into(), task);
2127 }
2128
    /// Runs `tool` with `input` (unless the tool is disabled, which
    /// immediately yields an error result) and returns a task that records
    /// the tool's output on the thread and triggers finished-handling when
    /// it completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A disabled tool never runs; it produces an error result directly.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2179
2180 fn tool_finished(
2181 &mut self,
2182 tool_use_id: LanguageModelToolUseId,
2183 pending_tool_use: Option<PendingToolUse>,
2184 canceled: bool,
2185 window: Option<AnyWindowHandle>,
2186 cx: &mut Context<Self>,
2187 ) {
2188 if self.all_tools_finished() {
2189 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2190 if !canceled {
2191 self.send_to_model(model.clone(), window, cx);
2192 }
2193 self.auto_capture_telemetry(cx);
2194 }
2195 }
2196
2197 cx.emit(ThreadEvent::ToolFinished {
2198 tool_use_id,
2199 pending_tool_use,
2200 });
2201 }
2202
2203 /// Cancels the last pending completion, if there are any pending.
2204 ///
2205 /// Returns whether a completion was canceled.
2206 pub fn cancel_last_completion(
2207 &mut self,
2208 window: Option<AnyWindowHandle>,
2209 cx: &mut Context<Self>,
2210 ) -> bool {
2211 let mut canceled = self.pending_completions.pop().is_some();
2212
2213 for pending_tool_use in self.tool_use.cancel_pending() {
2214 canceled = true;
2215 self.tool_finished(
2216 pending_tool_use.id.clone(),
2217 Some(pending_tool_use),
2218 true,
2219 window,
2220 cx,
2221 );
2222 }
2223
2224 self.finalize_pending_checkpoint(cx);
2225
2226 if canceled {
2227 cx.emit(ThreadEvent::CompletionCanceled);
2228 }
2229
2230 canceled
2231 }
2232
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2240
    /// Returns the thread-level feedback, if the user has given any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2244
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2248
    /// Records feedback for a single message and reports it via telemetry,
    /// attaching the serialized thread, the message content, the enabled tool
    /// names, and a project snapshot.
    ///
    /// Returns an immediately-ready task (no-op) when the same feedback has
    /// already been recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Start collecting the async payloads before mutating state.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        // Snapshot the message text now; the message may change after return.
        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to `null` rather than dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            // NOTE: the telemetry event's field names come from these local
            // identifiers — renaming them changes the emitted schema.
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2306
    /// Records thread-level feedback and reports it via telemetry.
    ///
    /// When the thread contains an assistant message, the feedback is
    /// attributed to the most recent one via [`Self::report_message_feedback`];
    /// otherwise a thread-level event (without message fields) is emitted.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to `null` rather than dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                // NOTE: telemetry field names come from these local identifiers.
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2352
2353 /// Create a snapshot of the current project state including git information and unsaved buffers.
2354 fn project_snapshot(
2355 project: Entity<Project>,
2356 cx: &mut Context<Self>,
2357 ) -> Task<Arc<ProjectSnapshot>> {
2358 let git_store = project.read(cx).git_store().clone();
2359 let worktree_snapshots: Vec<_> = project
2360 .read(cx)
2361 .visible_worktrees(cx)
2362 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2363 .collect();
2364
2365 cx.spawn(async move |_, cx| {
2366 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2367
2368 let mut unsaved_buffers = Vec::new();
2369 cx.update(|app_cx| {
2370 let buffer_store = project.read(app_cx).buffer_store();
2371 for buffer_handle in buffer_store.read(app_cx).buffers() {
2372 let buffer = buffer_handle.read(app_cx);
2373 if buffer.is_dirty() {
2374 if let Some(file) = buffer.file() {
2375 let path = file.path().to_string_lossy().to_string();
2376 unsaved_buffers.push(path);
2377 }
2378 }
2379 }
2380 })
2381 .ok();
2382
2383 Arc::new(ProjectSnapshot {
2384 worktree_snapshots,
2385 unsaved_buffer_paths: unsaved_buffers,
2386 timestamp: Utc::now(),
2387 })
2388 })
2389 }
2390
    /// Captures a single worktree's path plus (best-effort) git state:
    /// remote URL, HEAD sha, current branch, and a HEAD-to-worktree diff.
    ///
    /// Every failure path degrades to `None` fields — this is telemetry data,
    /// so errors are swallowed rather than propagated.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app/entity is gone, return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository containing this worktree's root, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        // Query the backend on the repository's job queue; only
                        // local repositories expose a backend to inspect.
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the Option<Result<Task<..>>> layers, treating any failure
            // (dropped entity, failed job) as "no git state".
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2468
2469 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2470 let mut markdown = Vec::new();
2471
2472 let summary = self.summary().or_default();
2473 writeln!(markdown, "# {summary}\n")?;
2474
2475 for message in self.messages() {
2476 writeln!(
2477 markdown,
2478 "## {role}\n",
2479 role = match message.role {
2480 Role::User => "User",
2481 Role::Assistant => "Agent",
2482 Role::System => "System",
2483 }
2484 )?;
2485
2486 if !message.loaded_context.text.is_empty() {
2487 writeln!(markdown, "{}", message.loaded_context.text)?;
2488 }
2489
2490 if !message.loaded_context.images.is_empty() {
2491 writeln!(
2492 markdown,
2493 "\n{} images attached as context.\n",
2494 message.loaded_context.images.len()
2495 )?;
2496 }
2497
2498 for segment in &message.segments {
2499 match segment {
2500 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2501 MessageSegment::Thinking { text, .. } => {
2502 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2503 }
2504 MessageSegment::RedactedThinking(_) => {}
2505 }
2506 }
2507
2508 for tool_use in self.tool_uses_for_message(message.id, cx) {
2509 writeln!(
2510 markdown,
2511 "**Use Tool: {} ({})**",
2512 tool_use.name, tool_use.id
2513 )?;
2514 writeln!(markdown, "```json")?;
2515 writeln!(
2516 markdown,
2517 "{}",
2518 serde_json::to_string_pretty(&tool_use.input)?
2519 )?;
2520 writeln!(markdown, "```")?;
2521 }
2522
2523 for tool_result in self.tool_results_for_message(message.id) {
2524 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2525 if tool_result.is_error {
2526 write!(markdown, " (Error)")?;
2527 }
2528
2529 writeln!(markdown, "**\n")?;
2530 match &tool_result.content {
2531 LanguageModelToolResultContent::Text(text)
2532 | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
2533 text,
2534 ..
2535 }) => {
2536 writeln!(markdown, "{text}")?;
2537 }
2538 LanguageModelToolResultContent::Image(image) => {
2539 writeln!(markdown, "", image.source)?;
2540 }
2541 }
2542
2543 if let Some(output) = tool_result.output.as_ref() {
2544 writeln!(
2545 markdown,
2546 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2547 serde_json::to_string_pretty(output)?
2548 )?;
2549 }
2550 }
2551 }
2552
2553 Ok(String::from_utf8_lossy(&markdown).to_string())
2554 }
2555
2556 pub fn keep_edits_in_range(
2557 &mut self,
2558 buffer: Entity<language::Buffer>,
2559 buffer_range: Range<language::Anchor>,
2560 cx: &mut Context<Self>,
2561 ) {
2562 self.action_log.update(cx, |action_log, cx| {
2563 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2564 });
2565 }
2566
2567 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2568 self.action_log
2569 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2570 }
2571
2572 pub fn reject_edits_in_ranges(
2573 &mut self,
2574 buffer: Entity<language::Buffer>,
2575 buffer_ranges: Vec<Range<language::Anchor>>,
2576 cx: &mut Context<Self>,
2577 ) -> Task<Result<()>> {
2578 self.action_log.update(cx, |action_log, cx| {
2579 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2580 })
2581 }
2582
2583 pub fn action_log(&self) -> &Entity<ActionLog> {
2584 &self.action_log
2585 }
2586
2587 pub fn project(&self) -> &Entity<Project> {
2588 &self.project
2589 }
2590
2591 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2592 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2593 return;
2594 }
2595
2596 let now = Instant::now();
2597 if let Some(last) = self.last_auto_capture_at {
2598 if now.duration_since(last).as_secs() < 10 {
2599 return;
2600 }
2601 }
2602
2603 self.last_auto_capture_at = Some(now);
2604
2605 let thread_id = self.id().clone();
2606 let github_login = self
2607 .project
2608 .read(cx)
2609 .user_store()
2610 .read(cx)
2611 .current_user()
2612 .map(|user| user.github_login.clone());
2613 let client = self.project.read(cx).client();
2614 let serialize_task = self.serialize(cx);
2615
2616 cx.background_executor()
2617 .spawn(async move {
2618 if let Ok(serialized_thread) = serialize_task.await {
2619 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2620 telemetry::event!(
2621 "Agent Thread Auto-Captured",
2622 thread_id = thread_id.to_string(),
2623 thread_data = thread_data,
2624 auto_capture_reason = "tracked_user",
2625 github_login = github_login
2626 );
2627
2628 client.telemetry().flush_events().await;
2629 }
2630 }
2631 })
2632 .detach();
2633 }
2634
2635 pub fn cumulative_token_usage(&self) -> TokenUsage {
2636 self.cumulative_token_usage
2637 }
2638
2639 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2640 let Some(model) = self.configured_model.as_ref() else {
2641 return TotalTokenUsage::default();
2642 };
2643
2644 let max = model.model.max_token_count();
2645
2646 let index = self
2647 .messages
2648 .iter()
2649 .position(|msg| msg.id == message_id)
2650 .unwrap_or(0);
2651
2652 if index == 0 {
2653 return TotalTokenUsage { total: 0, max };
2654 }
2655
2656 let token_usage = &self
2657 .request_token_usage
2658 .get(index - 1)
2659 .cloned()
2660 .unwrap_or_default();
2661
2662 TotalTokenUsage {
2663 total: token_usage.total_tokens() as usize,
2664 max,
2665 }
2666 }
2667
2668 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2669 let model = self.configured_model.as_ref()?;
2670
2671 let max = model.model.max_token_count();
2672
2673 if let Some(exceeded_error) = &self.exceeded_window_error {
2674 if model.model.id() == exceeded_error.model_id {
2675 return Some(TotalTokenUsage {
2676 total: exceeded_error.token_count,
2677 max,
2678 });
2679 }
2680 }
2681
2682 let total = self
2683 .token_usage_at_last_message()
2684 .unwrap_or_default()
2685 .total_tokens() as usize;
2686
2687 Some(TotalTokenUsage { total, max })
2688 }
2689
2690 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2691 self.request_token_usage
2692 .get(self.messages.len().saturating_sub(1))
2693 .or_else(|| self.request_token_usage.last())
2694 .cloned()
2695 }
2696
2697 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2698 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2699 self.request_token_usage
2700 .resize(self.messages.len(), placeholder);
2701
2702 if let Some(last) = self.request_token_usage.last_mut() {
2703 *last = token_usage;
2704 }
2705 }
2706
2707 pub fn deny_tool_use(
2708 &mut self,
2709 tool_use_id: LanguageModelToolUseId,
2710 tool_name: Arc<str>,
2711 window: Option<AnyWindowHandle>,
2712 cx: &mut Context<Self>,
2713 ) {
2714 let err = Err(anyhow::anyhow!(
2715 "Permission to run tool action denied by user"
2716 ));
2717
2718 self.tool_use.insert_tool_output(
2719 tool_use_id.clone(),
2720 tool_name,
2721 err,
2722 self.configured_model.as_ref(),
2723 );
2724 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2725 }
2726}
2727
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's plan requires payment before further requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The request limit for the user's current plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and detail text for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2740
/// Events emitted by a thread as completions stream and tools run, consumed
/// by UI layers such as the active thread view.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// Streamed text content for the given assistant message.
    StreamedAssistantText(MessageId, String),
    /// Streamed thinking/reasoning content for the given assistant message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool-use block is being streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be parsed as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or with an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (emitted by `Thread::tool_finished`).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    /// Listeners should cancel any in-progress message editing.
    CancelEditing,
    /// A pending completion was canceled (see `cancel_last_completion`).
    CompletionCanceled,
}
2783
// `Thread` emits `ThreadEvent`s so observers can react to streaming progress.
impl EventEmitter<ThreadEvent> for Thread {}
2785
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Monotonic id used to identify this completion among pending ones.
    id: usize,
    queue_state: QueueState,
    // NOTE(review): held so the completion task stays alive for the lifetime
    // of this struct — confirm gpui's task-drop semantics.
    _task: Task<()>,
}
2791
2792#[cfg(test)]
2793mod tests {
2794 use super::*;
2795 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2796 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2797 use assistant_tool::ToolRegistry;
2798 use editor::EditorSettings;
2799 use gpui::TestAppContext;
2800 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2801 use project::{FakeFs, Project};
2802 use prompt_store::PromptBuilder;
2803 use serde_json::json;
2804 use settings::{Settings, SettingsStore};
2805 use std::sync::Arc;
2806 use theme::ThemeSettings;
2807 use util::path;
2808 use workspace::Workspace;
2809
    /// Verifies that a user message with file context stores the rendered
    /// context on the message and prepends it in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Message 0 is the system prompt; message 1 is the user's.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2885
    /// Verifies that each message only carries context that was not already
    /// attached to an earlier message, and that `new_context_for_thread`
    /// respects an "up to message" boundary and context removal.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a boundary at message 2, contexts from message 3 onward (plus
        // the new file4) count as "new" again.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3041
    /// Verifies that messages inserted without any context have empty context
    /// text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Message 0 is the system prompt; message 1 is the user's.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3119
    /// Verifies that after a context buffer is edited, the next completion
    /// request ends with a "files changed since last read" notification.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3206
    /// Verifies that the `model_parameters` temperature override is applied
    /// when the provider and/or model matches, and ignored when the provider
    /// differs even if the model name matches.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        // Provider mismatch: the override must not apply.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3300
    /// Exercises the full thread-summary lifecycle: starts `Pending`, moves to
    /// `Generating` after the first exchange, then `Ready` once the fake model
    /// streams text — and verifies manual summaries are rejected until then.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief".into());
        fake_model.stream_last_completion_response(" Introduction".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3380
    /// After summary generation fails (`ThreadSummary::Error`), setting a
    /// summary manually should succeed and move the state to `Ready`.
    #[gpui::test]
    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the ThreadSummary::Error state.
        test_summarize_error(&model, &thread, cx);

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });
    }
3402
    /// After summary generation fails, sending further messages does not retry
    /// automatically — but an explicit `summarize` call does, and a successful
    /// stream then yields `ThreadSummary::Ready`.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the ThreadSummary::Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the manual retry.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3453
    /// Shared helper: sends one message, lets the fake model reply, then ends
    /// the summary stream without producing any text so the thread finishes in
    /// `ThreadSummary::Error`.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Summary generation kicks off after the first full exchange.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate the summary request ending with an empty stream (no text).
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3483
    /// Streams a canned assistant reply through `fake_model` and runs the test
    /// executor until idle, so the thread fully processes the completed turn.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3490
    /// Registers the global settings and subsystems these tests depend on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // Install the test settings store before registering the
            // individual settings below.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3507
    /// Creates a test `Project` rooted at `/test`, backed by a `FakeFs`
    /// populated from the given `files` JSON tree.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
3517
    /// Builds the standard test fixtures for `project`: a workspace window, a
    /// `ThreadStore`, a fresh thread, and a `ContextStore` — and installs a
    /// `FakeLanguageModelProvider` model as both the default model and the
    /// thread-summary model so tests can script every completion.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model for both regular completions and summaries,
        // so every request in these tests flows through the fake.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
3572
3573 async fn add_file_to_context(
3574 project: &Entity<Project>,
3575 context_store: &Entity<ContextStore>,
3576 path: &str,
3577 cx: &mut TestAppContext,
3578 ) -> Result<Entity<language::Buffer>> {
3579 let buffer_path = project
3580 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3581 .unwrap();
3582
3583 let buffer = project
3584 .update(cx, |project, cx| {
3585 project.open_buffer(buffer_path.clone(), cx)
3586 })
3587 .await
3588 .unwrap();
3589
3590 context_store.update(cx, |context_store, cx| {
3591 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3592 });
3593
3594 Ok(buffer)
3595 }
3596}