1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage, WrappedTextContent,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
/// Unique identifier for a [`Thread`], stored as a cheaply clonable string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Stored as a cheaply clonable string (a freshly generated UUID).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content chunks (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context (files, etc.) that was attached to this message when sent.
    pub loaded_context: LoadedContext,
    /// Creases used to re-render context pills when editing this message.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// One typed chunk of a [`Message`] body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Chain-of-thought text, optionally accompanied by a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking payload; opaque bytes, never displayed.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has visible content worth rendering.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is never displayable.
    pub fn should_display(&self) -> bool {
        match self {
            // Fixed: these arms previously returned `text.is_empty()`, which
            // inverted the check so that only *empty* segments were considered
            // displayable — contradicting both this method's name and the
            // `Message::should_display_content` documentation.
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Point-in-time snapshot of the project captured when a thread starts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved edits at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// None when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Captured git repository metadata for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text, when one could be computed.
    pub diff: Option<String>,
}

/// Associates a git checkpoint with the message it was taken before.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User's thumbs-up / thumbs-down rating of a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
229pub enum LastRestoreCheckpoint {
230 Pending {
231 message_id: MessageId,
232 },
233 Error {
234 message_id: MessageId,
235 error: String,
236 },
237}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
249pub enum DetailedSummaryState {
250 #[default]
251 NotGenerated,
252 Generating {
253 message_id: MessageId,
254 },
255 Generated {
256 text: SharedString,
257 message_id: MessageId,
258 },
259}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token total for a thread, relative to the model's context-window
/// maximum.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: usize,
    /// Context-window size; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, nearing the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via the
        // ZED_THREAD_WARNING_THRESHOLD environment variable. A missing or
        // malformed value falls back to the default instead of panicking
        // (the previous `.parse().unwrap()` aborted on bad input, and
        // `unwrap_or(String)` allocated even when the variable was set).
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total; `max` is
    /// unchanged. Saturates instead of overflowing.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total.saturating_add(tokens),
            max: self.max,
        }
    }
}

/// How close the thread is to its model's context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// Position of a pending completion in the provider's request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request is being sent; no queue information yet.
    Sending,
    /// Request is queued behind `position` other requests.
    Queued { position: usize },
    /// The provider has started streaming the completion.
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    // In-flight task generating the short summary.
    pending_summary: Task<Option<()>>,
    // In-flight task generating the detailed summary.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // Ordered by strictly increasing `MessageId` (relied on by `Self::message`'s binary search).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per user message, for restore/rollback.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken at the latest user message; finalized or discarded when generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history (serialized with the thread).
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Test/debug hook invoked with each request and its raw response events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Remaining agentic turns; decremented by `send_to_model`, u32::MAX = effectively unlimited.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
364#[derive(Clone, Debug, PartialEq, Eq)]
365pub enum ThreadSummary {
366 Pending,
367 Generating,
368 Ready(SharedString),
369 Error,
370}
371
372impl ThreadSummary {
373 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
374
375 pub fn or_default(&self) -> SharedString {
376 self.unwrap_or(Self::DEFAULT)
377 }
378
379 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
380 self.ready().unwrap_or_else(|| message.into())
381 }
382
383 pub fn ready(&self) -> Option<SharedString> {
384 match self {
385 ThreadSummary::Ready(summary) => Some(summary.clone()),
386 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
387 }
388 }
389}
390
/// Records that the last request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
398
399impl Thread {
    /// Creates an empty thread with a fresh id, configured with the global
    /// default model and the preferred completion mode from settings.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project snapshot eagerly so it reflects the state
            // at thread creation; consumers await the shared task later.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            // Unlimited turns by default.
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
453
    /// Reconstructs a thread from its serialized form.
    ///
    /// Restores messages, tool-use state, and the detailed summary. The
    /// serialized model selection is re-resolved against the registry,
    /// falling back to the default model when unavailable. `window` is None
    /// in headless mode.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest existing message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // live context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // None marks the crease as deserialized (no live context).
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
573
    /// Installs a debug/test hook invoked with each outgoing request and the
    /// raw completion events received for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// This thread's unique id.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// True when the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt id, marking the next request as a fresh user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
605
    /// Returns the configured model, lazily falling back to the registry's
    /// default model when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any (no fallback).
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Current state of the short thread summary.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
625
    /// Replaces the summary text, emitting [`ThreadEvent::SummaryChanged`] on change.
    ///
    /// No-op while a summary is pending or generating. An empty input is
    /// coerced to the default placeholder.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // NOTE(review): in the Error state we compare against DEFAULT, so
            // setting a summary equal to "New Thread" after an error will not
            // transition out of Error — confirm this is intended.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the completion mode for this thread.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
652
653 pub fn message(&self, id: MessageId) -> Option<&Message> {
654 let index = self
655 .messages
656 .binary_search_by(|message| message.id.cmp(&id))
657 .ok()?;
658
659 self.messages.get(index)
660 }
661
662 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
663 self.messages.iter()
664 }
665
    /// True while a completion is in flight or tool uses are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    ///
    /// NOTE(review): the code only checks `last_received_chunk_at` — it
    /// returns None when no chunk has ever been received, not when
    /// `is_generating()` is false. Confirm the doc or the implementation.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before generation counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived (resets staleness).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
688
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by its id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// True when any tool use is still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
714
    /// Restores the project to `checkpoint`'s git state.
    ///
    /// Marks the restore as pending (emitting `CheckpointChanged`), then
    /// asynchronously applies the git checkpoint. On success the thread is
    /// truncated back to the checkpointed message; on failure the error is
    /// recorded in `last_restore_checkpoint`. The pending checkpoint is
    /// cleared either way.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the checkpointed message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
750
    /// Decides whether the pending checkpoint should be kept once generation ends.
    ///
    /// Does nothing while still generating or when no checkpoint is pending.
    /// Otherwise compares the pending git checkpoint against the repository's
    /// current state and records it only when the repo actually changed
    /// (a failed comparison or failed current-checkpoint also keeps it).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // A failed comparison is treated as "changed" so the
                    // checkpoint is preserved rather than silently dropped.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a current checkpoint: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
789
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint-restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
800
    /// Removes the message with `message_id` and every later message, along
    /// with their recorded checkpoints. No-op when the id is not found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        // rposition: match the last occurrence of the id.
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Context items attached to the message with the given id (empty iterator
    /// when the message is not found).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
822
823 pub fn is_turn_end(&self, ix: usize) -> bool {
824 if self.messages.is_empty() {
825 return false;
826 }
827
828 if !self.is_generating() && ix == self.messages.len() - 1 {
829 return true;
830 }
831
832 let Some(message) = self.messages.get(ix) else {
833 return false;
834 };
835
836 if message.role != Role::Assistant {
837 return false;
838 }
839
840 self.messages
841 .get(ix + 1)
842 .and_then(|message| {
843 self.message(message.id)
844 .map(|next_message| next_message.role == Role::User)
845 })
846 .unwrap_or(false)
847 }
848
    /// Usage reported by the provider for the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// True when the provider signalled that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
866
    /// Tool uses issued by the assistant message with the given id.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced in response to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result recorded for a specific tool use, if finished.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Textual output of a tool use; None for image results or unknown ids.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text)
            | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => {
                Some(text)
            }
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// UI card associated with a tool result, if one was registered.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
898
899 /// Return tools that are both enabled and supported by the model
900 pub fn available_tools(
901 &self,
902 cx: &App,
903 model: Arc<dyn LanguageModel>,
904 ) -> Vec<LanguageModelRequestTool> {
905 if model.supports_tools() {
906 self.tools()
907 .read(cx)
908 .enabled_tools(cx)
909 .into_iter()
910 .filter_map(|tool| {
911 // Skip tools that cannot be supported
912 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
913 Some(LanguageModelRequestTool {
914 name: tool.name(),
915 description: tool.description(),
916 input_schema,
917 })
918 })
919 .collect()
920 } else {
921 Vec::default()
922 }
923 }
924
    /// Appends a user message with its loaded context and creases.
    ///
    /// Buffers referenced by the context are marked as read in the action
    /// log, and `git_checkpoint` (when provided) becomes the pending
    /// checkpoint associated with the new message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
960
    /// Appends an assistant message with no context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message, assigning it the next id and emitting
    /// [`ThreadEvent::MessageAdded`]. Returns the new message's id.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
995
    /// Replaces an existing message's role and segments in place.
    ///
    /// Context is replaced only when `loaded_context` is Some; a provided git
    /// `checkpoint` is recorded for the message. Emits
    /// [`ThreadEvent::MessageEdited`]. Returns false when the id is unknown.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1026
1027 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1028 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1029 return false;
1030 };
1031 self.messages.remove(index);
1032 self.touch_updated_at();
1033 cx.emit(ThreadEvent::MessageDeleted(id));
1034 true
1035 }
1036
1037 /// Returns the representation of this [`Thread`] in a textual form.
1038 ///
1039 /// This is the representation we use when attaching a thread as context to another thread.
1040 pub fn text(&self) -> String {
1041 let mut text = String::new();
1042
1043 for message in &self.messages {
1044 text.push_str(match message.role {
1045 language_model::Role::User => "User:",
1046 language_model::Role::Assistant => "Agent:",
1047 language_model::Role::System => "System:",
1048 });
1049 text.push('\n');
1050
1051 for segment in &message.segments {
1052 match segment {
1053 MessageSegment::Text(content) => text.push_str(content),
1054 MessageSegment::Thinking { text: content, .. } => {
1055 text.push_str(&format!("<think>{}</think>", content))
1056 }
1057 MessageSegment::RedactedThinking(_) => {}
1058 }
1059 }
1060 text.push('\n');
1061 }
1062
1063 text
1064 }
1065
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot first, then captures messages
    /// (with their tool uses/results and creases), token usage, summary
    /// state, and the selected model/completion mode.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1148
    /// Returns how many agentic turns this thread may still take before it
    /// stops sending requests to the model.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1152
    /// Sets the number of agentic turns this thread may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1156
1157 pub fn send_to_model(
1158 &mut self,
1159 model: Arc<dyn LanguageModel>,
1160 window: Option<AnyWindowHandle>,
1161 cx: &mut Context<Self>,
1162 ) {
1163 if self.remaining_turns == 0 {
1164 return;
1165 }
1166
1167 self.remaining_turns -= 1;
1168
1169 let request = self.to_completion_request(model.clone(), cx);
1170
1171 self.stream_completion(request, model, window, cx);
1172 }
1173
1174 pub fn used_tools_since_last_user_message(&self) -> bool {
1175 for message in self.messages.iter().rev() {
1176 if self.tool_use.message_has_tool_results(message.id) {
1177 return true;
1178 } else if message.role == Role::User {
1179 return false;
1180 }
1181 }
1182
1183 false
1184 }
1185
    /// Builds the complete `LanguageModelRequest` for this thread: system
    /// prompt, conversation messages (with tool uses and their results),
    /// stale-file notices, and the set of available tools.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt needs the project context, which loads
        // asynchronously; surface an error instead of silently sending a
        // request without a system prompt.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message that is safe to mark for prompt
        // caching; only messages whose tool results are all complete qualify.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            // Empty text/thinking segments are skipped; redacted thinking is
            // always forwarded verbatim.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool results are delivered in a synthetic User
            // message immediately after the message that used the tools.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Max mode is only requested when the model actually supports it.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1343
1344 fn to_summarize_request(
1345 &self,
1346 model: &Arc<dyn LanguageModel>,
1347 added_user_message: String,
1348 cx: &App,
1349 ) -> LanguageModelRequest {
1350 let mut request = LanguageModelRequest {
1351 thread_id: None,
1352 prompt_id: None,
1353 mode: None,
1354 messages: vec![],
1355 tools: Vec::new(),
1356 tool_choice: None,
1357 stop: Vec::new(),
1358 temperature: AssistantSettings::temperature_for_model(model, cx),
1359 };
1360
1361 for message in &self.messages {
1362 let mut request_message = LanguageModelRequestMessage {
1363 role: message.role,
1364 content: Vec::new(),
1365 cache: false,
1366 };
1367
1368 for segment in &message.segments {
1369 match segment {
1370 MessageSegment::Text(text) => request_message
1371 .content
1372 .push(MessageContent::Text(text.clone())),
1373 MessageSegment::Thinking { .. } => {}
1374 MessageSegment::RedactedThinking(_) => {}
1375 }
1376 }
1377
1378 if request_message.content.is_empty() {
1379 continue;
1380 }
1381
1382 request.messages.push(request_message);
1383 }
1384
1385 request.messages.push(LanguageModelRequestMessage {
1386 role: Role::User,
1387 content: vec![MessageContent::Text(added_user_message)],
1388 cache: false,
1389 });
1390
1391 request
1392 }
1393
1394 fn attached_tracked_files_state(
1395 &self,
1396 messages: &mut Vec<LanguageModelRequestMessage>,
1397 cx: &App,
1398 ) {
1399 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1400
1401 let mut stale_message = String::new();
1402
1403 let action_log = self.action_log.read(cx);
1404
1405 for stale_file in action_log.stale_buffers(cx) {
1406 let Some(file) = stale_file.read(cx).file() else {
1407 continue;
1408 };
1409
1410 if stale_message.is_empty() {
1411 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1412 }
1413
1414 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1415 }
1416
1417 let mut content = Vec::with_capacity(2);
1418
1419 if !stale_message.is_empty() {
1420 content.push(stale_message.into());
1421 }
1422
1423 if !content.is_empty() {
1424 let context_message = LanguageModelRequestMessage {
1425 role: Role::User,
1426 content,
1427 cache: false,
1428 };
1429
1430 messages.push(context_message);
1431 }
1432 }
1433
1434 pub fn stream_completion(
1435 &mut self,
1436 request: LanguageModelRequest,
1437 model: Arc<dyn LanguageModel>,
1438 window: Option<AnyWindowHandle>,
1439 cx: &mut Context<Self>,
1440 ) {
1441 self.tool_use_limit_reached = false;
1442
1443 let pending_completion_id = post_inc(&mut self.completion_count);
1444 let mut request_callback_parameters = if self.request_callback.is_some() {
1445 Some((request.clone(), Vec::new()))
1446 } else {
1447 None
1448 };
1449 let prompt_id = self.last_prompt_id.clone();
1450 let tool_use_metadata = ToolUseMetadata {
1451 model: model.clone(),
1452 thread_id: self.id.clone(),
1453 prompt_id: prompt_id.clone(),
1454 };
1455
1456 self.last_received_chunk_at = Some(Instant::now());
1457
1458 let task = cx.spawn(async move |thread, cx| {
1459 let stream_completion_future = model.stream_completion(request, &cx);
1460 let initial_token_usage =
1461 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1462 let stream_completion = async {
1463 let mut events = stream_completion_future.await?;
1464
1465 let mut stop_reason = StopReason::EndTurn;
1466 let mut current_token_usage = TokenUsage::default();
1467
1468 thread
1469 .update(cx, |_thread, cx| {
1470 cx.emit(ThreadEvent::NewRequest);
1471 })
1472 .ok();
1473
1474 let mut request_assistant_message_id = None;
1475
1476 while let Some(event) = events.next().await {
1477 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1478 response_events
1479 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1480 }
1481
1482 thread.update(cx, |thread, cx| {
1483 let event = match event {
1484 Ok(event) => event,
1485 Err(LanguageModelCompletionError::BadInputJson {
1486 id,
1487 tool_name,
1488 raw_input: invalid_input_json,
1489 json_parse_error,
1490 }) => {
1491 thread.receive_invalid_tool_json(
1492 id,
1493 tool_name,
1494 invalid_input_json,
1495 json_parse_error,
1496 window,
1497 cx,
1498 );
1499 return Ok(());
1500 }
1501 Err(LanguageModelCompletionError::Other(error)) => {
1502 return Err(error);
1503 }
1504 };
1505
1506 match event {
1507 LanguageModelCompletionEvent::StartMessage { .. } => {
1508 request_assistant_message_id =
1509 Some(thread.insert_assistant_message(
1510 vec![MessageSegment::Text(String::new())],
1511 cx,
1512 ));
1513 }
1514 LanguageModelCompletionEvent::Stop(reason) => {
1515 stop_reason = reason;
1516 }
1517 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1518 thread.update_token_usage_at_last_message(token_usage);
1519 thread.cumulative_token_usage = thread.cumulative_token_usage
1520 + token_usage
1521 - current_token_usage;
1522 current_token_usage = token_usage;
1523 }
1524 LanguageModelCompletionEvent::Text(chunk) => {
1525 thread.received_chunk();
1526
1527 cx.emit(ThreadEvent::ReceivedTextChunk);
1528 if let Some(last_message) = thread.messages.last_mut() {
1529 if last_message.role == Role::Assistant
1530 && !thread.tool_use.has_tool_results(last_message.id)
1531 {
1532 last_message.push_text(&chunk);
1533 cx.emit(ThreadEvent::StreamedAssistantText(
1534 last_message.id,
1535 chunk,
1536 ));
1537 } else {
1538 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1539 // of a new Assistant response.
1540 //
1541 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1542 // will result in duplicating the text of the chunk in the rendered Markdown.
1543 request_assistant_message_id =
1544 Some(thread.insert_assistant_message(
1545 vec![MessageSegment::Text(chunk.to_string())],
1546 cx,
1547 ));
1548 };
1549 }
1550 }
1551 LanguageModelCompletionEvent::Thinking {
1552 text: chunk,
1553 signature,
1554 } => {
1555 thread.received_chunk();
1556
1557 if let Some(last_message) = thread.messages.last_mut() {
1558 if last_message.role == Role::Assistant
1559 && !thread.tool_use.has_tool_results(last_message.id)
1560 {
1561 last_message.push_thinking(&chunk, signature);
1562 cx.emit(ThreadEvent::StreamedAssistantThinking(
1563 last_message.id,
1564 chunk,
1565 ));
1566 } else {
1567 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1568 // of a new Assistant response.
1569 //
1570 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1571 // will result in duplicating the text of the chunk in the rendered Markdown.
1572 request_assistant_message_id =
1573 Some(thread.insert_assistant_message(
1574 vec![MessageSegment::Thinking {
1575 text: chunk.to_string(),
1576 signature,
1577 }],
1578 cx,
1579 ));
1580 };
1581 }
1582 }
1583 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1584 let last_assistant_message_id = request_assistant_message_id
1585 .unwrap_or_else(|| {
1586 let new_assistant_message_id =
1587 thread.insert_assistant_message(vec![], cx);
1588 request_assistant_message_id =
1589 Some(new_assistant_message_id);
1590 new_assistant_message_id
1591 });
1592
1593 let tool_use_id = tool_use.id.clone();
1594 let streamed_input = if tool_use.is_input_complete {
1595 None
1596 } else {
1597 Some((&tool_use.input).clone())
1598 };
1599
1600 let ui_text = thread.tool_use.request_tool_use(
1601 last_assistant_message_id,
1602 tool_use,
1603 tool_use_metadata.clone(),
1604 cx,
1605 );
1606
1607 if let Some(input) = streamed_input {
1608 cx.emit(ThreadEvent::StreamedToolUse {
1609 tool_use_id,
1610 ui_text,
1611 input,
1612 });
1613 }
1614 }
1615 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1616 if let Some(completion) = thread
1617 .pending_completions
1618 .iter_mut()
1619 .find(|completion| completion.id == pending_completion_id)
1620 {
1621 match status_update {
1622 CompletionRequestStatus::Queued {
1623 position,
1624 } => {
1625 completion.queue_state = QueueState::Queued { position };
1626 }
1627 CompletionRequestStatus::Started => {
1628 completion.queue_state = QueueState::Started;
1629 }
1630 CompletionRequestStatus::Failed {
1631 code, message, request_id
1632 } => {
1633 anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
1634 }
1635 CompletionRequestStatus::UsageUpdated {
1636 amount, limit
1637 } => {
1638 let usage = RequestUsage { limit, amount: amount as i32 };
1639
1640 thread.last_usage = Some(usage);
1641 }
1642 CompletionRequestStatus::ToolUseLimitReached => {
1643 thread.tool_use_limit_reached = true;
1644 }
1645 }
1646 }
1647 }
1648 }
1649
1650 thread.touch_updated_at();
1651 cx.emit(ThreadEvent::StreamedCompletion);
1652 cx.notify();
1653
1654 thread.auto_capture_telemetry(cx);
1655 Ok(())
1656 })??;
1657
1658 smol::future::yield_now().await;
1659 }
1660
1661 thread.update(cx, |thread, cx| {
1662 thread.last_received_chunk_at = None;
1663 thread
1664 .pending_completions
1665 .retain(|completion| completion.id != pending_completion_id);
1666
1667 // If there is a response without tool use, summarize the message. Otherwise,
1668 // allow two tool uses before summarizing.
1669 if matches!(thread.summary, ThreadSummary::Pending)
1670 && thread.messages.len() >= 2
1671 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1672 {
1673 thread.summarize(cx);
1674 }
1675 })?;
1676
1677 anyhow::Ok(stop_reason)
1678 };
1679
1680 let result = stream_completion.await;
1681
1682 thread
1683 .update(cx, |thread, cx| {
1684 thread.finalize_pending_checkpoint(cx);
1685 match result.as_ref() {
1686 Ok(stop_reason) => match stop_reason {
1687 StopReason::ToolUse => {
1688 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1689 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1690 }
1691 StopReason::EndTurn | StopReason::MaxTokens => {
1692 thread.project.update(cx, |project, cx| {
1693 project.set_agent_location(None, cx);
1694 });
1695 }
1696 },
1697 Err(error) => {
1698 thread.project.update(cx, |project, cx| {
1699 project.set_agent_location(None, cx);
1700 });
1701
1702 if error.is::<PaymentRequiredError>() {
1703 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1704 } else if let Some(error) =
1705 error.downcast_ref::<ModelRequestLimitReachedError>()
1706 {
1707 cx.emit(ThreadEvent::ShowError(
1708 ThreadError::ModelRequestLimitReached { plan: error.plan },
1709 ));
1710 } else if let Some(known_error) =
1711 error.downcast_ref::<LanguageModelKnownError>()
1712 {
1713 match known_error {
1714 LanguageModelKnownError::ContextWindowLimitExceeded {
1715 tokens,
1716 } => {
1717 thread.exceeded_window_error = Some(ExceededWindowError {
1718 model_id: model.id(),
1719 token_count: *tokens,
1720 });
1721 cx.notify();
1722 }
1723 }
1724 } else {
1725 let error_message = error
1726 .chain()
1727 .map(|err| err.to_string())
1728 .collect::<Vec<_>>()
1729 .join("\n");
1730 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1731 header: "Error interacting with language model".into(),
1732 message: SharedString::from(error_message.clone()),
1733 }));
1734 }
1735
1736 thread.cancel_last_completion(window, cx);
1737 }
1738 }
1739 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1740
1741 if let Some((request_callback, (request, response_events))) = thread
1742 .request_callback
1743 .as_mut()
1744 .zip(request_callback_parameters.as_ref())
1745 {
1746 request_callback(request, response_events);
1747 }
1748
1749 thread.auto_capture_telemetry(cx);
1750
1751 if let Ok(initial_usage) = initial_token_usage {
1752 let usage = thread.cumulative_token_usage - initial_usage;
1753
1754 telemetry::event!(
1755 "Assistant Thread Completion",
1756 thread_id = thread.id().to_string(),
1757 prompt_id = prompt_id,
1758 model = model.telemetry_id(),
1759 model_provider = model.provider_id().to_string(),
1760 input_tokens = usage.input_tokens,
1761 output_tokens = usage.output_tokens,
1762 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1763 cache_read_input_tokens = usage.cache_read_input_tokens,
1764 );
1765 }
1766 })
1767 .ok();
1768 });
1769
1770 self.pending_completions.push(PendingCompletion {
1771 id: pending_completion_id,
1772 queue_state: QueueState::Sending,
1773 _task: task,
1774 });
1775 }
1776
1777 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1778 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1779 println!("No thread summary model");
1780 return;
1781 };
1782
1783 if !model.provider.is_authenticated(cx) {
1784 return;
1785 }
1786
1787 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1788 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1789 If the conversation is about a specific subject, include it in the title. \
1790 Be descriptive. DO NOT speak in the first person.";
1791
1792 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1793
1794 self.summary = ThreadSummary::Generating;
1795
1796 self.pending_summary = cx.spawn(async move |this, cx| {
1797 let result = async {
1798 let mut messages = model.model.stream_completion(request, &cx).await?;
1799
1800 let mut new_summary = String::new();
1801 while let Some(event) = messages.next().await {
1802 let Ok(event) = event else {
1803 continue;
1804 };
1805 let text = match event {
1806 LanguageModelCompletionEvent::Text(text) => text,
1807 LanguageModelCompletionEvent::StatusUpdate(
1808 CompletionRequestStatus::UsageUpdated { amount, limit },
1809 ) => {
1810 this.update(cx, |thread, _cx| {
1811 thread.last_usage = Some(RequestUsage {
1812 limit,
1813 amount: amount as i32,
1814 });
1815 })?;
1816 continue;
1817 }
1818 _ => continue,
1819 };
1820
1821 let mut lines = text.lines();
1822 new_summary.extend(lines.next());
1823
1824 // Stop if the LLM generated multiple lines.
1825 if lines.next().is_some() {
1826 break;
1827 }
1828 }
1829
1830 anyhow::Ok(new_summary)
1831 }
1832 .await;
1833
1834 this.update(cx, |this, cx| {
1835 match result {
1836 Ok(new_summary) => {
1837 if new_summary.is_empty() {
1838 this.summary = ThreadSummary::Error;
1839 } else {
1840 this.summary = ThreadSummary::Ready(new_summary.into());
1841 }
1842 }
1843 Err(err) => {
1844 this.summary = ThreadSummary::Error;
1845 log::error!("Failed to generate thread summary: {}", err);
1846 }
1847 }
1848 cx.emit(ThreadEvent::SummaryGenerated);
1849 })
1850 .log_err()?;
1851
1852 Some(())
1853 });
1854 }
1855
    /// Starts generating a detailed Markdown summary of the thread if one is
    /// not already generated (or being generated) for the latest message.
    ///
    /// On success the result is broadcast through `detailed_summary_tx` and
    /// the thread is saved so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails outright, reset the state so a later call
            // can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1945
1946 pub async fn wait_for_detailed_summary_or_text(
1947 this: &Entity<Self>,
1948 cx: &mut AsyncApp,
1949 ) -> Option<SharedString> {
1950 let mut detailed_summary_rx = this
1951 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1952 .ok()?;
1953 loop {
1954 match detailed_summary_rx.recv().await? {
1955 DetailedSummaryState::Generating { .. } => {}
1956 DetailedSummaryState::NotGenerated => {
1957 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1958 }
1959 DetailedSummaryState::Generated { text, .. } => return Some(text),
1960 }
1961 }
1962 }
1963
1964 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1965 self.detailed_summary_rx
1966 .borrow()
1967 .text()
1968 .unwrap_or_else(|| self.text().into())
1969 }
1970
    /// Whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1977
    /// Dispatches every idle pending tool use: runs it directly, requests
    /// user confirmation when the tool (and settings) demand it, or records
    /// an error for tools that don't exist. Returns the tool uses that were
    /// pending when called.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the full request for context, shared via Arc.
        let request = Arc::new(self.to_completion_request(model.clone(), cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                // Confirmation is skipped when the user opted into
                // always-allow in the assistant settings.
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool we don't have.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2031
2032 pub fn handle_hallucinated_tool_use(
2033 &mut self,
2034 tool_use_id: LanguageModelToolUseId,
2035 hallucinated_tool_name: Arc<str>,
2036 window: Option<AnyWindowHandle>,
2037 cx: &mut Context<Thread>,
2038 ) {
2039 let available_tools = self.tools.read(cx).enabled_tools(cx);
2040
2041 let tool_list = available_tools
2042 .iter()
2043 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2044 .collect::<Vec<_>>()
2045 .join("\n");
2046
2047 let error_message = format!(
2048 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2049 hallucinated_tool_name, tool_list
2050 );
2051
2052 let pending_tool_use = self.tool_use.insert_tool_output(
2053 tool_use_id.clone(),
2054 hallucinated_tool_name,
2055 Err(anyhow!("Missing tool call: {error_message}")),
2056 self.configured_model.as_ref(),
2057 );
2058
2059 cx.emit(ThreadEvent::MissingToolUse {
2060 tool_use_id: tool_use_id.clone(),
2061 ui_text: error_message.into(),
2062 });
2063
2064 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2065 }
2066
2067 pub fn receive_invalid_tool_json(
2068 &mut self,
2069 tool_use_id: LanguageModelToolUseId,
2070 tool_name: Arc<str>,
2071 invalid_json: Arc<str>,
2072 error: String,
2073 window: Option<AnyWindowHandle>,
2074 cx: &mut Context<Thread>,
2075 ) {
2076 log::error!("The model returned invalid input JSON: {invalid_json}");
2077
2078 let pending_tool_use = self.tool_use.insert_tool_output(
2079 tool_use_id.clone(),
2080 tool_name,
2081 Err(anyhow!("Error parsing input JSON: {error}")),
2082 self.configured_model.as_ref(),
2083 );
2084 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2085 pending_tool_use.ui_text.clone()
2086 } else {
2087 log::error!(
2088 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2089 );
2090 format!("Unknown tool {}", tool_use_id).into()
2091 };
2092
2093 cx.emit(ThreadEvent::InvalidToolInput {
2094 tool_use_id: tool_use_id.clone(),
2095 ui_text,
2096 invalid_input_json: invalid_json,
2097 });
2098
2099 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2100 }
2101
2102 pub fn run_tool(
2103 &mut self,
2104 tool_use_id: LanguageModelToolUseId,
2105 ui_text: impl Into<SharedString>,
2106 input: serde_json::Value,
2107 request: Arc<LanguageModelRequest>,
2108 tool: Arc<dyn Tool>,
2109 model: Arc<dyn LanguageModel>,
2110 window: Option<AnyWindowHandle>,
2111 cx: &mut Context<Thread>,
2112 ) {
2113 let task =
2114 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2115 self.tool_use
2116 .run_pending_tool(tool_use_id, ui_text.into(), task);
2117 }
2118
    /// Runs `tool` with `input` and returns a task that records the tool's
    /// output on this thread when it completes.
    ///
    /// Disabled tools short-circuit to an error result without running.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // Record the output and let the thread decide whether to
                // continue the conversation.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2169
2170 fn tool_finished(
2171 &mut self,
2172 tool_use_id: LanguageModelToolUseId,
2173 pending_tool_use: Option<PendingToolUse>,
2174 canceled: bool,
2175 window: Option<AnyWindowHandle>,
2176 cx: &mut Context<Self>,
2177 ) {
2178 if self.all_tools_finished() {
2179 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2180 if !canceled {
2181 self.send_to_model(model.clone(), window, cx);
2182 }
2183 self.auto_capture_telemetry(cx);
2184 }
2185 }
2186
2187 cx.emit(ThreadEvent::ToolFinished {
2188 tool_use_id,
2189 pending_tool_use,
2190 });
2191 }
2192
2193 /// Cancels the last pending completion, if there are any pending.
2194 ///
2195 /// Returns whether a completion was canceled.
2196 pub fn cancel_last_completion(
2197 &mut self,
2198 window: Option<AnyWindowHandle>,
2199 cx: &mut Context<Self>,
2200 ) -> bool {
2201 let mut canceled = self.pending_completions.pop().is_some();
2202
2203 for pending_tool_use in self.tool_use.cancel_pending() {
2204 canceled = true;
2205 self.tool_finished(
2206 pending_tool_use.id.clone(),
2207 Some(pending_tool_use),
2208 true,
2209 window,
2210 cx,
2211 );
2212 }
2213
2214 self.finalize_pending_checkpoint(cx);
2215
2216 if canceled {
2217 cx.emit(ThreadEvent::CompletionCanceled);
2218 }
2219
2220 canceled
2221 }
2222
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits
    /// `ThreadEvent::CancelEditing`; it performs no cancellation itself.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2230
    /// Returns the thread-level feedback rating, if one has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2234
    /// Returns the feedback recorded for `message_id`, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2238
    /// Records `feedback` for a single message and reports it, along with a
    /// serialized copy of the thread and a final project snapshot, to
    /// telemetry. No-ops when the same rating is submitted twice.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Submitting the same rating again is a no-op.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything the telemetry event needs before moving to the
        // background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure shouldn't block the rating itself.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2296
    /// Reports a feedback rating for the thread as a whole.
    ///
    /// When the thread contains at least one assistant message, the rating is
    /// attached to the most recent one via `report_message_feedback`.
    /// Otherwise the rating is stored at the thread level and a telemetry
    /// event (without message-specific fields) is emitted directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: capture the snapshot/serialization on
            // the main thread, then report in the background.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // JSON conversion failures degrade to a null payload.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2342
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Kick off one snapshot task per visible worktree; they are awaited
        // concurrently below via `join_all`.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect paths of all dirty (unsaved) buffers. If the app context
            // is gone the `.ok()` below leaves the list empty.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2380
2381 fn worktree_snapshot(
2382 worktree: Entity<project::Worktree>,
2383 git_store: Entity<GitStore>,
2384 cx: &App,
2385 ) -> Task<WorktreeSnapshot> {
2386 cx.spawn(async move |cx| {
2387 // Get worktree path and snapshot
2388 let worktree_info = cx.update(|app_cx| {
2389 let worktree = worktree.read(app_cx);
2390 let path = worktree.abs_path().to_string_lossy().to_string();
2391 let snapshot = worktree.snapshot();
2392 (path, snapshot)
2393 });
2394
2395 let Ok((worktree_path, _snapshot)) = worktree_info else {
2396 return WorktreeSnapshot {
2397 worktree_path: String::new(),
2398 git_state: None,
2399 };
2400 };
2401
2402 let git_state = git_store
2403 .update(cx, |git_store, cx| {
2404 git_store
2405 .repositories()
2406 .values()
2407 .find(|repo| {
2408 repo.read(cx)
2409 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2410 .is_some()
2411 })
2412 .cloned()
2413 })
2414 .ok()
2415 .flatten()
2416 .map(|repo| {
2417 repo.update(cx, |repo, _| {
2418 let current_branch =
2419 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2420 repo.send_job(None, |state, _| async move {
2421 let RepositoryState::Local { backend, .. } = state else {
2422 return GitState {
2423 remote_url: None,
2424 head_sha: None,
2425 current_branch,
2426 diff: None,
2427 };
2428 };
2429
2430 let remote_url = backend.remote_url("origin");
2431 let head_sha = backend.head_sha().await;
2432 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2433
2434 GitState {
2435 remote_url,
2436 head_sha,
2437 current_branch,
2438 diff,
2439 }
2440 })
2441 })
2442 });
2443
2444 let git_state = match git_state {
2445 Some(git_state) => match git_state.ok() {
2446 Some(git_state) => git_state.await.ok(),
2447 None => None,
2448 },
2449 None => None,
2450 };
2451
2452 WorktreeSnapshot {
2453 worktree_path,
2454 git_state,
2455 }
2456 })
2457 }
2458
2459 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2460 let mut markdown = Vec::new();
2461
2462 let summary = self.summary().or_default();
2463 writeln!(markdown, "# {summary}\n")?;
2464
2465 for message in self.messages() {
2466 writeln!(
2467 markdown,
2468 "## {role}\n",
2469 role = match message.role {
2470 Role::User => "User",
2471 Role::Assistant => "Agent",
2472 Role::System => "System",
2473 }
2474 )?;
2475
2476 if !message.loaded_context.text.is_empty() {
2477 writeln!(markdown, "{}", message.loaded_context.text)?;
2478 }
2479
2480 if !message.loaded_context.images.is_empty() {
2481 writeln!(
2482 markdown,
2483 "\n{} images attached as context.\n",
2484 message.loaded_context.images.len()
2485 )?;
2486 }
2487
2488 for segment in &message.segments {
2489 match segment {
2490 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2491 MessageSegment::Thinking { text, .. } => {
2492 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2493 }
2494 MessageSegment::RedactedThinking(_) => {}
2495 }
2496 }
2497
2498 for tool_use in self.tool_uses_for_message(message.id, cx) {
2499 writeln!(
2500 markdown,
2501 "**Use Tool: {} ({})**",
2502 tool_use.name, tool_use.id
2503 )?;
2504 writeln!(markdown, "```json")?;
2505 writeln!(
2506 markdown,
2507 "{}",
2508 serde_json::to_string_pretty(&tool_use.input)?
2509 )?;
2510 writeln!(markdown, "```")?;
2511 }
2512
2513 for tool_result in self.tool_results_for_message(message.id) {
2514 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2515 if tool_result.is_error {
2516 write!(markdown, " (Error)")?;
2517 }
2518
2519 writeln!(markdown, "**\n")?;
2520 match &tool_result.content {
2521 LanguageModelToolResultContent::Text(text)
2522 | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
2523 text,
2524 ..
2525 }) => {
2526 writeln!(markdown, "{text}")?;
2527 }
2528 LanguageModelToolResultContent::Image(image) => {
2529 writeln!(markdown, "", image.source)?;
2530 }
2531 }
2532
2533 if let Some(output) = tool_result.output.as_ref() {
2534 writeln!(
2535 markdown,
2536 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2537 serde_json::to_string_pretty(output)?
2538 )?;
2539 }
2540 }
2541 }
2542
2543 Ok(String::from_utf8_lossy(&markdown).to_string())
2544 }
2545
    /// Accepts (keeps) the agent's edits within the given range of a buffer,
    /// delegating to the thread's [`ActionLog`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2556
    /// Accepts (keeps) all of the agent's pending edits, delegating to the
    /// thread's [`ActionLog`].
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2561
    /// Rejects (reverts) the agent's edits within the given ranges of a
    /// buffer, delegating to the thread's [`ActionLog`].
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2572
    /// Returns the thread's action log, which tracks the agent's edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2576
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2580
2581 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2582 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2583 return;
2584 }
2585
2586 let now = Instant::now();
2587 if let Some(last) = self.last_auto_capture_at {
2588 if now.duration_since(last).as_secs() < 10 {
2589 return;
2590 }
2591 }
2592
2593 self.last_auto_capture_at = Some(now);
2594
2595 let thread_id = self.id().clone();
2596 let github_login = self
2597 .project
2598 .read(cx)
2599 .user_store()
2600 .read(cx)
2601 .current_user()
2602 .map(|user| user.github_login.clone());
2603 let client = self.project.read(cx).client();
2604 let serialize_task = self.serialize(cx);
2605
2606 cx.background_executor()
2607 .spawn(async move {
2608 if let Ok(serialized_thread) = serialize_task.await {
2609 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2610 telemetry::event!(
2611 "Agent Thread Auto-Captured",
2612 thread_id = thread_id.to_string(),
2613 thread_data = thread_data,
2614 auto_capture_reason = "tracked_user",
2615 github_login = github_login
2616 );
2617
2618 client.telemetry().flush_events().await;
2619 }
2620 }
2621 })
2622 .detach();
2623 }
2624
    /// Returns the token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2628
2629 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2630 let Some(model) = self.configured_model.as_ref() else {
2631 return TotalTokenUsage::default();
2632 };
2633
2634 let max = model.model.max_token_count();
2635
2636 let index = self
2637 .messages
2638 .iter()
2639 .position(|msg| msg.id == message_id)
2640 .unwrap_or(0);
2641
2642 if index == 0 {
2643 return TotalTokenUsage { total: 0, max };
2644 }
2645
2646 let token_usage = &self
2647 .request_token_usage
2648 .get(index - 1)
2649 .cloned()
2650 .unwrap_or_default();
2651
2652 TotalTokenUsage {
2653 total: token_usage.total_tokens() as usize,
2654 max,
2655 }
2656 }
2657
2658 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2659 let model = self.configured_model.as_ref()?;
2660
2661 let max = model.model.max_token_count();
2662
2663 if let Some(exceeded_error) = &self.exceeded_window_error {
2664 if model.model.id() == exceeded_error.model_id {
2665 return Some(TotalTokenUsage {
2666 total: exceeded_error.token_count,
2667 max,
2668 });
2669 }
2670 }
2671
2672 let total = self
2673 .token_usage_at_last_message()
2674 .unwrap_or_default()
2675 .total_tokens() as usize;
2676
2677 Some(TotalTokenUsage { total, max })
2678 }
2679
    /// Returns the token usage recorded for the last message.
    ///
    /// If `request_token_usage` has no entry at the last message's index
    /// (e.g. it is shorter than the message list), falls back to the most
    /// recent recorded entry; returns `None` only when no usage was recorded.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2686
    /// Records `token_usage` as the usage for the last message.
    ///
    /// First grows (or shrinks) `request_token_usage` to match the message
    /// count — padding any new slots with the previously-recorded usage as a
    /// placeholder — then overwrites the final entry with the new value.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2696
2697 pub fn deny_tool_use(
2698 &mut self,
2699 tool_use_id: LanguageModelToolUseId,
2700 tool_name: Arc<str>,
2701 window: Option<AnyWindowHandle>,
2702 cx: &mut Context<Self>,
2703 ) {
2704 let err = Err(anyhow::anyhow!(
2705 "Permission to run tool action denied by user"
2706 ));
2707
2708 self.tool_use.insert_tool_output(
2709 tool_use_id.clone(),
2710 tool_name,
2711 err,
2712 self.configured_model.as_ref(),
2713 );
2714 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2715 }
2716}
2717
/// User-visible errors surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's account requires payment before further requests succeed.
    #[error("Payment required")]
    PaymentRequired,
    /// The model request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a short header and a detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2730
/// Events emitted by a `Thread` as completions stream and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// A completion event was streamed from the model.
    StreamedCompletion,
    /// A chunk of text was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant response text for the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use is streaming in from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// An expected tool use was not produced.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input JSON that could not be parsed.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (emitted by `tool_finished`, including for
    /// canceled and denied tools).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Listeners (like ActiveThread) should cancel in-progress editing.
    CancelEditing,
    /// An in-flight completion was canceled (see `cancel_last_completion`).
    CompletionCanceled,
}
2773
// Marker impl allowing `Thread` entities to emit `ThreadEvent`s via `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
2775
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier distinguishing this completion from others.
    id: usize,
    // Queue status for this completion — presumably updated from streamed
    // request-status events; confirm against the producer of this struct.
    queue_state: QueueState,
    // Holds the streaming task; popping the entry (see
    // `cancel_last_completion`) drops it, which cancels the request.
    _task: Task<()>,
}
2781
2782#[cfg(test)]
2783mod tests {
2784 use super::*;
2785 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2786 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2787 use assistant_tool::ToolRegistry;
2788 use editor::EditorSettings;
2789 use gpui::TestAppContext;
2790 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2791 use project::{FakeFs, Project};
2792 use prompt_store::PromptBuilder;
2793 use serde_json::json;
2794 use settings::{Settings, SettingsStore};
2795 use std::sync::Arc;
2796 use theme::ThemeSettings;
2797 use util::path;
2798 use workspace::Workspace;
2799
    /// Verifies that a user message with an attached file context stores the
    /// rendered `<context>` block on the message and prepends it to the
    /// message text in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Message 0 is the system prompt; message 1 is the user message with
        // the context text prepended.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2875
    /// Verifies that each new user message only includes context not already
    /// attached to an earlier message, and that `new_context_for_thread`
    /// respects an "exclude messages after" cursor and context removal.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages (plus the system prompt at index 0)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With the cursor at message 2, contexts from message 3 onward count
        // as "new" again (file3), along with the newly added file4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3031
    /// Verifies that user messages without any attached context have empty
    /// context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Message 0 is the system prompt.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3109
    /// Verifies that when a context buffer is edited after being attached,
    /// the next completion request ends with a "These files changed since
    /// last read" notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3196
    /// Verifies how `model_parameters` settings select a request temperature:
    /// a matching provider+model, model-only, or provider-only entry applies
    /// the temperature, while an entry naming a different provider does not.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3290
    /// Walks the summary state machine: `Pending` (ignores manual sets) →
    /// `Generating` after the first exchange (still ignores manual sets) →
    /// `Ready` with the streamed text, after which manual sets apply and an
    /// empty summary falls back to the default.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then finish the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief".into());
        fake_model.stream_last_completion_response(" Introduction".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3370
    // A summary stuck in the Error state can be overwritten manually via
    // `set_summary`, which transitions it to Ready.
    #[gpui::test]
    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the summary into the Error state first.
        test_summarize_error(&model, &thread, cx);

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });
    }
3392
    // After a summarization failure, further messages must not retrigger
    // summarization automatically; an explicit `summarize` call retries
    // the request and can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the summary into the Error state first.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the retried request.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3443
    // Shared helper for the summary-error tests: sends a message, completes
    // the assistant response, then ends the summary completion stream
    // without streaming any text, leaving the thread in
    // `ThreadSummary::Error`.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Summary generation kicks off once the assistant has replied.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending without producing any text
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3473
    // Drives the pending completion to a successful finish: streams a canned
    // assistant reply and closes the stream, parking the executor before and
    // after so any spawned work settles.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3480
    // Registers the globals and settings that the thread tests depend on.
    // The test settings store is installed as a global before the per-crate
    // `init`/`register` calls below.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3497
3498 // Helper to create a test project with test files
3499 async fn create_test_project(
3500 cx: &mut TestAppContext,
3501 files: serde_json::Value,
3502 ) -> Entity<Project> {
3503 let fs = FakeFs::new(cx.executor());
3504 fs.insert_tree(path!("/test"), files).await;
3505 Project::test(fs, [path!("/test").as_ref()], cx).await
3506 }
3507
3508 async fn setup_test_environment(
3509 cx: &mut TestAppContext,
3510 project: Entity<Project>,
3511 ) -> (
3512 Entity<Workspace>,
3513 Entity<ThreadStore>,
3514 Entity<Thread>,
3515 Entity<ContextStore>,
3516 Arc<dyn LanguageModel>,
3517 ) {
3518 let (workspace, cx) =
3519 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3520
3521 let thread_store = cx
3522 .update(|_, cx| {
3523 ThreadStore::load(
3524 project.clone(),
3525 cx.new(|_| ToolWorkingSet::default()),
3526 None,
3527 Arc::new(PromptBuilder::new(None).unwrap()),
3528 cx,
3529 )
3530 })
3531 .await
3532 .unwrap();
3533
3534 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3535 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3536
3537 let provider = Arc::new(FakeLanguageModelProvider);
3538 let model = provider.test_model();
3539 let model: Arc<dyn LanguageModel> = Arc::new(model);
3540
3541 cx.update(|_, cx| {
3542 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3543 registry.set_default_model(
3544 Some(ConfiguredModel {
3545 provider: provider.clone(),
3546 model: model.clone(),
3547 }),
3548 cx,
3549 );
3550 registry.set_thread_summary_model(
3551 Some(ConfiguredModel {
3552 provider,
3553 model: model.clone(),
3554 }),
3555 cx,
3556 );
3557 })
3558 });
3559
3560 (workspace, thread_store, thread, context_store, model)
3561 }
3562
3563 async fn add_file_to_context(
3564 project: &Entity<Project>,
3565 context_store: &Entity<ContextStore>,
3566 path: &str,
3567 cx: &mut TestAppContext,
3568 ) -> Result<Entity<language::Buffer>> {
3569 let buffer_path = project
3570 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3571 .unwrap();
3572
3573 let buffer = project
3574 .update(cx, |project, cx| {
3575 project.open_buffer(buffer_path.clone(), cx)
3576 })
3577 .await
3578 .unwrap();
3579
3580 context_store.update(cx, |context_store, cx| {
3581 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3582 });
3583
3584 Ok(buffer)
3585 }
3586}