1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
/// Monotonically increasing identifier for a [`Message`] within a [`Thread`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Offset range of the crease within the message text.
    pub range: Range<usize>,
    /// Icon and label used to render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (plain text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to the message when it was sent.
    pub loaded_context: LoadedContext,
    /// Crease data for restoring context markers in editors.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// One piece of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary message text.
    Text(String),
    /// Model "thinking" output, with an optional provider-supplied signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted thinking payload; never rendered.
    RedactedThinking(Vec<u8>),
}
185
186impl MessageSegment {
187 pub fn should_display(&self) -> bool {
188 match self {
189 Self::Text(text) => text.is_empty(),
190 Self::Thinking { text, .. } => text.is_empty(),
191 Self::RedactedThinking(_) => false,
192 }
193 }
194}
195
/// Point-in-time snapshot of the project's worktrees (see
/// `Thread::initial_project_snapshot`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff text, when one could be computed.
    pub diff: Option<String>,
}

/// Associates a git checkpoint with the message it was captured for, so the
/// project can later be restored to the state before that message
/// (see `restore_checkpoint`).
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up/down) on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
229pub enum LastRestoreCheckpoint {
230 Pending {
231 message_id: MessageId,
232 },
233 Error {
234 message_id: MessageId,
235 error: String,
236 },
237}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
249pub enum DetailedSummaryState {
250 #[default]
251 NotGenerated,
252 Generating {
253 message_id: MessageId,
254 },
255 Generated {
256 text: SharedString,
257 message_id: MessageId,
258 },
259}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token-usage total for a thread, paired with the model's
/// context-window size.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage against the context window.
    ///
    /// In debug builds the warning threshold can be overridden via the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable; a missing or
    /// unparsable value falls back to the default.
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        // Fixed: previously this allocated a fallback String eagerly
        // (`unwrap_or("0.8".to_string())`) and panicked via `.unwrap()` when
        // the env var was set to a non-numeric value; now it falls back.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total; `max` is unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Severity bucket for context-window consumption.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// Queue status of a pending completion request.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request is being sent; no queue information yet.
    Sending,
    /// Request is waiting in a queue at the given position
    /// (semantics of `position` defined by the provider — TODO confirm).
    Queued { position: usize },
    /// Request has started streaming.
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last modification time (see `touch_updated_at`).
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    /// In-flight task generating the short summary, if any.
    pending_summary: Task<Option<()>>,
    /// In-flight task generating the detailed summary, if any.
    detailed_summary_task: Task<Option<()>>,
    /// Watch channel pair broadcasting `DetailedSummaryState` changes.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Messages ordered by ascending `MessageId`; IDs are assigned
    /// monotonically in `insert_message`, which `message()`'s binary search
    /// relies on.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per message, for restore/rollback.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Most recent checkpoint-restore attempt (pending or failed), if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured for the latest user message, finalized once
    /// generation completes (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    /// Thread-level feedback, distinct from per-message `message_feedback`.
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Time the most recent streamed chunk arrived; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Optional hook invoked with each request and its accumulated events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Number of agentic turns still permitted; `u32::MAX` means effectively
    /// unlimited. Decremented by `send_to_model`.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
364#[derive(Clone, Debug, PartialEq, Eq)]
365pub enum ThreadSummary {
366 Pending,
367 Generating,
368 Ready(SharedString),
369 Error,
370}
371
372impl ThreadSummary {
373 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
374
375 pub fn or_default(&self) -> SharedString {
376 self.unwrap_or(Self::DEFAULT)
377 }
378
379 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
380 self.ready().unwrap_or_else(|| message.into())
381 }
382
383 pub fn ready(&self) -> Option<SharedString> {
384 match self {
385 ThreadSummary::Ready(summary) => Some(summary.clone()),
386 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
387 }
388 }
389}
390
/// Recorded when a request fails because the conversation no longer fits in
/// the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
398
399impl Thread {
    /// Creates an empty thread for `project`, using the registry's default
    /// model and the globally configured completion mode.
    ///
    /// Kicks off an asynchronous capture of the initial project snapshot.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project snapshot in the background; `Shared` lets
            // later `serialize` calls await the same result.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
453
    /// Reconstructs a thread from its serialized form.
    ///
    /// The serialized model is resolved against the registry, falling back to
    /// the default model; the serialized completion mode falls back to the
    /// global setting. Crease contexts are not restored (`context: None`).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest serialized message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was last used with; otherwise the
        // registry default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // structured contexts and images are not persisted.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }

    /// Installs a hook that observes each completion request together with
    /// the response events accumulated for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
581
    /// The unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Returns whether the thread has no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Last modification time of the thread.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Replaces `last_prompt_id` with a freshly generated ID.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, initializing it from the registry's
    /// default model when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Current summary state.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Sets a new summary.
    ///
    /// No-op while a summary is pending or being generated. An empty summary
    /// is replaced with [`ThreadSummary::DEFAULT`]; the event is only emitted
    /// when the text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// The thread's completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the thread's completion mode.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
652
653 pub fn message(&self, id: MessageId) -> Option<&Message> {
654 let index = self
655 .messages
656 .binary_search_by(|message| message.id.cmp(&id))
657 .ok()?;
658
659 self.messages.get(index)
660 }
661
662 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
663 self.messages.iter()
664 }
665
    /// Returns whether a completion is in flight or any tool use is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale, i.e.
    /// more than `STALE_THRESHOLD` ms have passed since the last chunk.
    ///
    /// Returns `None` when no chunk has been received yet.
    /// NOTE(review): the previous doc claimed this returns `None` whenever
    /// `is_generating()` is false, but `last_received_chunk_at` is not
    /// cleared here when generation ends — callers should gate on
    /// `is_generating()` themselves. Confirm intended contract.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk was just received.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that require user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Returns whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
714
    /// Restores the project's git state to `checkpoint`; on success, the
    /// thread is truncated back to the checkpoint's message. Progress and
    /// failure are surfaced via `last_restore_checkpoint`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so the UI can display it.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Restoring invalidates whatever checkpoint was pending.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }

    /// Once generation has fully finished, compares the pending checkpoint
    /// against the current repository state and records it only when they
    /// differ (i.e. the turn actually changed something). A comparison or
    /// checkpoint failure conservatively records the checkpoint.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    /// Records `checkpoint` against its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint-restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
800
801 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
802 let Some(message_ix) = self
803 .messages
804 .iter()
805 .rposition(|message| message.id == message_id)
806 else {
807 return;
808 };
809 for deleted_message in self.messages.drain(message_ix..) {
810 self.checkpoints_by_message.remove(&deleted_message.id);
811 }
812 cx.notify();
813 }
814
815 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
816 self.messages
817 .iter()
818 .find(|message| message.id == id)
819 .into_iter()
820 .flat_map(|message| message.loaded_context.contexts.iter())
821 }
822
823 pub fn is_turn_end(&self, ix: usize) -> bool {
824 if self.messages.is_empty() {
825 return false;
826 }
827
828 if !self.is_generating() && ix == self.messages.len() - 1 {
829 return true;
830 }
831
832 let Some(message) = self.messages.get(ix) else {
833 return false;
834 };
835
836 if message.role != Role::Assistant {
837 return false;
838 }
839
840 self.messages
841 .get(ix + 1)
842 .and_then(|message| {
843 self.message(message.id)
844 .map(|next_message| next_message.role == Role::User)
845 })
846 .unwrap_or(false)
847 }
848
    /// Usage reported for the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Returns whether the tool-use limit has been reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses issued by the message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of the tool use with `id`, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of the tool use with `id`, if available.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with the tool use with `id`, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
889
890 /// Return tools that are both enabled and supported by the model
891 pub fn available_tools(
892 &self,
893 cx: &App,
894 model: Arc<dyn LanguageModel>,
895 ) -> Vec<LanguageModelRequestTool> {
896 if model.supports_tools() {
897 self.tools()
898 .read(cx)
899 .enabled_tools(cx)
900 .into_iter()
901 .filter_map(|tool| {
902 // Skip tools that cannot be supported
903 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
904 Some(LanguageModelRequestTool {
905 name: tool.name(),
906 description: tool.description(),
907 input_schema,
908 })
909 })
910 .collect()
911 } else {
912 Vec::default()
913 }
914 }
915
    /// Appends a user message with the given text, context, and creases, and
    /// records `git_checkpoint` as the pending checkpoint for that message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Record buffers referenced by the attached context as read in the
        // action log.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // The pending checkpoint is kept or discarded later, once generation
        // finishes (see `finalize_pending_checkpoint`).
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
965
966 pub fn insert_message(
967 &mut self,
968 role: Role,
969 segments: Vec<MessageSegment>,
970 loaded_context: LoadedContext,
971 creases: Vec<MessageCrease>,
972 cx: &mut Context<Self>,
973 ) -> MessageId {
974 let id = self.next_message_id.post_inc();
975 self.messages.push(Message {
976 id,
977 role,
978 segments,
979 loaded_context,
980 creases,
981 });
982 self.touch_updated_at();
983 cx.emit(ThreadEvent::MessageAdded(id));
984 id
985 }
986
987 pub fn edit_message(
988 &mut self,
989 id: MessageId,
990 new_role: Role,
991 new_segments: Vec<MessageSegment>,
992 loaded_context: Option<LoadedContext>,
993 cx: &mut Context<Self>,
994 ) -> bool {
995 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
996 return false;
997 };
998 message.role = new_role;
999 message.segments = new_segments;
1000 if let Some(context) = loaded_context {
1001 message.loaded_context = context;
1002 }
1003 self.touch_updated_at();
1004 cx.emit(ThreadEvent::MessageEdited(id));
1005 true
1006 }
1007
1008 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1009 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1010 return false;
1011 };
1012 self.messages.remove(index);
1013 self.touch_updated_at();
1014 cx.emit(ThreadEvent::MessageDeleted(id));
1015 true
1016 }
1017
1018 /// Returns the representation of this [`Thread`] in a textual form.
1019 ///
1020 /// This is the representation we use when attaching a thread as context to another thread.
1021 pub fn text(&self) -> String {
1022 let mut text = String::new();
1023
1024 for message in &self.messages {
1025 text.push_str(match message.role {
1026 language_model::Role::User => "User:",
1027 language_model::Role::Assistant => "Agent:",
1028 language_model::Role::System => "System:",
1029 });
1030 text.push('\n');
1031
1032 for segment in &message.segments {
1033 match segment {
1034 MessageSegment::Text(content) => text.push_str(content),
1035 MessageSegment::Thinking { text: content, .. } => {
1036 text.push_str(&format!("<think>{}</think>", content))
1037 }
1038 MessageSegment::RedactedThinking(_) => {}
1039 }
1040 }
1041 text.push('\n');
1042 }
1043
1044 text
1045 }
1046
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Await the (shared) initial snapshot before reading the thread.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted (see
                        // `deserialize`, which restores it into
                        // `LoadedContext::text`).
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1129
    /// Number of agentic turns this thread may still take before it stops
    /// automatically sending follow-up requests to the model.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1133
    /// Resets the budget of agentic turns the thread may take on its own;
    /// `send_to_model` decrements this and refuses to run at zero.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1137
1138 pub fn send_to_model(
1139 &mut self,
1140 model: Arc<dyn LanguageModel>,
1141 window: Option<AnyWindowHandle>,
1142 cx: &mut Context<Self>,
1143 ) {
1144 if self.remaining_turns == 0 {
1145 return;
1146 }
1147
1148 self.remaining_turns -= 1;
1149
1150 let request = self.to_completion_request(model.clone(), cx);
1151
1152 self.stream_completion(request, model, window, cx);
1153 }
1154
1155 pub fn used_tools_since_last_user_message(&self) -> bool {
1156 for message in self.messages.iter().rev() {
1157 if self.tool_use.message_has_tool_results(message.id) {
1158 return true;
1159 } else if message.role == Role::User {
1160 return false;
1161 }
1162 }
1163
1164 false
1165 }
1166
    /// Builds the `LanguageModelRequest` for the next completion: system
    /// prompt, full conversation history (including tool uses and results),
    /// stale-file notices, available tools, and the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // The system-prompt template needs to know which tools exist so it
        // can describe them to the model.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // The project context loads asynchronously; reaching this branch
            // means a request was sent before it resolved.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible for prompt caching.
        // Only fully settled messages (no pending tool uses) qualify.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Empty text segments are skipped entirely.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results are delivered back to the model as a separate
            // user-role message immediately following the assistant message
            // that requested them.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use makes this message unstable, so it
                    // must not be chosen as the cache anchor.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Max mode is only requested when the model supports it; otherwise
        // fall back to the normal completion mode.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1324
1325 fn to_summarize_request(
1326 &self,
1327 model: &Arc<dyn LanguageModel>,
1328 added_user_message: String,
1329 cx: &App,
1330 ) -> LanguageModelRequest {
1331 let mut request = LanguageModelRequest {
1332 thread_id: None,
1333 prompt_id: None,
1334 mode: None,
1335 messages: vec![],
1336 tools: Vec::new(),
1337 tool_choice: None,
1338 stop: Vec::new(),
1339 temperature: AssistantSettings::temperature_for_model(model, cx),
1340 };
1341
1342 for message in &self.messages {
1343 let mut request_message = LanguageModelRequestMessage {
1344 role: message.role,
1345 content: Vec::new(),
1346 cache: false,
1347 };
1348
1349 for segment in &message.segments {
1350 match segment {
1351 MessageSegment::Text(text) => request_message
1352 .content
1353 .push(MessageContent::Text(text.clone())),
1354 MessageSegment::Thinking { .. } => {}
1355 MessageSegment::RedactedThinking(_) => {}
1356 }
1357 }
1358
1359 if request_message.content.is_empty() {
1360 continue;
1361 }
1362
1363 request.messages.push(request_message);
1364 }
1365
1366 request.messages.push(LanguageModelRequestMessage {
1367 role: Role::User,
1368 content: vec![MessageContent::Text(added_user_message)],
1369 cache: false,
1370 });
1371
1372 request
1373 }
1374
1375 fn attached_tracked_files_state(
1376 &self,
1377 messages: &mut Vec<LanguageModelRequestMessage>,
1378 cx: &App,
1379 ) {
1380 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1381
1382 let mut stale_message = String::new();
1383
1384 let action_log = self.action_log.read(cx);
1385
1386 for stale_file in action_log.stale_buffers(cx) {
1387 let Some(file) = stale_file.read(cx).file() else {
1388 continue;
1389 };
1390
1391 if stale_message.is_empty() {
1392 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1393 }
1394
1395 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1396 }
1397
1398 let mut content = Vec::with_capacity(2);
1399
1400 if !stale_message.is_empty() {
1401 content.push(stale_message.into());
1402 }
1403
1404 if !content.is_empty() {
1405 let context_message = LanguageModelRequestMessage {
1406 role: Role::User,
1407 content,
1408 cache: false,
1409 };
1410
1411 messages.push(context_message);
1412 }
1413 }
1414
    /// Spawns a streaming completion for `request` against `model` and wires
    /// the resulting event stream into the thread: text/thinking chunks, tool
    /// uses, token usage, queue status, and terminal stop/error handling.
    /// The task is tracked in `pending_completions` until it settles.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request/response pair when a debug callback is
        // installed; otherwise skip the extra clone.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the telemetry event below can
            // report this completion's delta.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message this stream is writing into; created
                // lazily on the first StartMessage/Text/Thinking/ToolUse event.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Malformed tool-input JSON is recoverable: report
                            // it to the model as a failed tool call.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            // Any other completion error aborts the stream.
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // add only the delta since the last update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially-streamed inputs are forwarded
                                // via StreamedToolUse; complete inputs go
                                // straight into request_tool_use.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so the UI can repaint while the
                    // stream is active.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            // The model asked to call tools: kick them off.
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to their dedicated UI
                            // surfaces before falling back to a generic error.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1761
1762 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1763 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1764 println!("No thread summary model");
1765 return;
1766 };
1767
1768 if !model.provider.is_authenticated(cx) {
1769 return;
1770 }
1771
1772 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1773 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1774 If the conversation is about a specific subject, include it in the title. \
1775 Be descriptive. DO NOT speak in the first person.";
1776
1777 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1778
1779 self.summary = ThreadSummary::Generating;
1780
1781 self.pending_summary = cx.spawn(async move |this, cx| {
1782 let result = async {
1783 let mut messages = model.model.stream_completion(request, &cx).await?;
1784
1785 let mut new_summary = String::new();
1786 while let Some(event) = messages.next().await {
1787 let Ok(event) = event else {
1788 continue;
1789 };
1790 let text = match event {
1791 LanguageModelCompletionEvent::Text(text) => text,
1792 LanguageModelCompletionEvent::StatusUpdate(
1793 CompletionRequestStatus::UsageUpdated { amount, limit },
1794 ) => {
1795 this.update(cx, |thread, _cx| {
1796 thread.last_usage = Some(RequestUsage {
1797 limit,
1798 amount: amount as i32,
1799 });
1800 })?;
1801 continue;
1802 }
1803 _ => continue,
1804 };
1805
1806 let mut lines = text.lines();
1807 new_summary.extend(lines.next());
1808
1809 // Stop if the LLM generated multiple lines.
1810 if lines.next().is_some() {
1811 break;
1812 }
1813 }
1814
1815 anyhow::Ok(new_summary)
1816 }
1817 .await;
1818
1819 this.update(cx, |this, cx| {
1820 match result {
1821 Ok(new_summary) => {
1822 if new_summary.is_empty() {
1823 this.summary = ThreadSummary::Error;
1824 } else {
1825 this.summary = ThreadSummary::Ready(new_summary.into());
1826 }
1827 }
1828 Err(err) => {
1829 this.summary = ThreadSummary::Error;
1830 log::error!("Failed to generate thread summary: {}", err);
1831 }
1832 }
1833 cx.emit(ThreadEvent::SummaryGenerated);
1834 })
1835 .log_err()?;
1836
1837 Some(())
1838 });
1839 }
1840
    /// Starts generating a detailed Markdown summary of the thread unless one
    /// is already generating or generated for the current last message.
    /// Progress is published through the `detailed_summary_tx` watch channel,
    /// and the thread is saved afterwards so the summary can be reused.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        // Mark the channel as generating for this message before spawning, so
        // concurrent callers see the in-flight state.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Reset to NotGenerated on failure so a later call retries.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1930
1931 pub async fn wait_for_detailed_summary_or_text(
1932 this: &Entity<Self>,
1933 cx: &mut AsyncApp,
1934 ) -> Option<SharedString> {
1935 let mut detailed_summary_rx = this
1936 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1937 .ok()?;
1938 loop {
1939 match detailed_summary_rx.recv().await? {
1940 DetailedSummaryState::Generating { .. } => {}
1941 DetailedSummaryState::NotGenerated => {
1942 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1943 }
1944 DetailedSummaryState::Generated { text, .. } => return Some(text),
1945 }
1946 }
1947 }
1948
1949 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1950 self.detailed_summary_rx
1951 .borrow()
1952 .text()
1953 .unwrap_or_else(|| self.text().into())
1954 }
1955
1956 pub fn is_generating_detailed_summary(&self) -> bool {
1957 matches!(
1958 &*self.detailed_summary_rx.borrow(),
1959 DetailedSummaryState::Generating { .. }
1960 )
1961 }
1962
1963 pub fn use_pending_tools(
1964 &mut self,
1965 window: Option<AnyWindowHandle>,
1966 cx: &mut Context<Self>,
1967 model: Arc<dyn LanguageModel>,
1968 ) -> Vec<PendingToolUse> {
1969 self.auto_capture_telemetry(cx);
1970 let request = Arc::new(self.to_completion_request(model.clone(), cx));
1971 let pending_tool_uses = self
1972 .tool_use
1973 .pending_tool_uses()
1974 .into_iter()
1975 .filter(|tool_use| tool_use.status.is_idle())
1976 .cloned()
1977 .collect::<Vec<_>>();
1978
1979 for tool_use in pending_tool_uses.iter() {
1980 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1981 if tool.needs_confirmation(&tool_use.input, cx)
1982 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1983 {
1984 self.tool_use.confirm_tool_use(
1985 tool_use.id.clone(),
1986 tool_use.ui_text.clone(),
1987 tool_use.input.clone(),
1988 request.clone(),
1989 tool,
1990 );
1991 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1992 } else {
1993 self.run_tool(
1994 tool_use.id.clone(),
1995 tool_use.ui_text.clone(),
1996 tool_use.input.clone(),
1997 request.clone(),
1998 tool,
1999 model.clone(),
2000 window,
2001 cx,
2002 );
2003 }
2004 } else {
2005 self.handle_hallucinated_tool_use(
2006 tool_use.id.clone(),
2007 tool_use.name.clone(),
2008 window,
2009 cx,
2010 );
2011 }
2012 }
2013
2014 pending_tool_uses
2015 }
2016
2017 pub fn handle_hallucinated_tool_use(
2018 &mut self,
2019 tool_use_id: LanguageModelToolUseId,
2020 hallucinated_tool_name: Arc<str>,
2021 window: Option<AnyWindowHandle>,
2022 cx: &mut Context<Thread>,
2023 ) {
2024 let available_tools = self.tools.read(cx).enabled_tools(cx);
2025
2026 let tool_list = available_tools
2027 .iter()
2028 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2029 .collect::<Vec<_>>()
2030 .join("\n");
2031
2032 let error_message = format!(
2033 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2034 hallucinated_tool_name, tool_list
2035 );
2036
2037 let pending_tool_use = self.tool_use.insert_tool_output(
2038 tool_use_id.clone(),
2039 hallucinated_tool_name,
2040 Err(anyhow!("Missing tool call: {error_message}")),
2041 self.configured_model.as_ref(),
2042 );
2043
2044 cx.emit(ThreadEvent::MissingToolUse {
2045 tool_use_id: tool_use_id.clone(),
2046 ui_text: error_message.into(),
2047 });
2048
2049 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2050 }
2051
2052 pub fn receive_invalid_tool_json(
2053 &mut self,
2054 tool_use_id: LanguageModelToolUseId,
2055 tool_name: Arc<str>,
2056 invalid_json: Arc<str>,
2057 error: String,
2058 window: Option<AnyWindowHandle>,
2059 cx: &mut Context<Thread>,
2060 ) {
2061 log::error!("The model returned invalid input JSON: {invalid_json}");
2062
2063 let pending_tool_use = self.tool_use.insert_tool_output(
2064 tool_use_id.clone(),
2065 tool_name,
2066 Err(anyhow!("Error parsing input JSON: {error}")),
2067 self.configured_model.as_ref(),
2068 );
2069 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2070 pending_tool_use.ui_text.clone()
2071 } else {
2072 log::error!(
2073 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2074 );
2075 format!("Unknown tool {}", tool_use_id).into()
2076 };
2077
2078 cx.emit(ThreadEvent::InvalidToolInput {
2079 tool_use_id: tool_use_id.clone(),
2080 ui_text,
2081 invalid_input_json: invalid_json,
2082 });
2083
2084 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2085 }
2086
2087 pub fn run_tool(
2088 &mut self,
2089 tool_use_id: LanguageModelToolUseId,
2090 ui_text: impl Into<SharedString>,
2091 input: serde_json::Value,
2092 request: Arc<LanguageModelRequest>,
2093 tool: Arc<dyn Tool>,
2094 model: Arc<dyn LanguageModel>,
2095 window: Option<AnyWindowHandle>,
2096 cx: &mut Context<Thread>,
2097 ) {
2098 let task =
2099 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2100 self.tool_use
2101 .run_pending_tool(tool_use_id, ui_text.into(), task);
2102 }
2103
2104 fn spawn_tool_use(
2105 &mut self,
2106 tool_use_id: LanguageModelToolUseId,
2107 request: Arc<LanguageModelRequest>,
2108 input: serde_json::Value,
2109 tool: Arc<dyn Tool>,
2110 model: Arc<dyn LanguageModel>,
2111 window: Option<AnyWindowHandle>,
2112 cx: &mut Context<Thread>,
2113 ) -> Task<()> {
2114 let tool_name: Arc<str> = tool.name().into();
2115
2116 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
2117 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
2118 } else {
2119 tool.run(
2120 input,
2121 request,
2122 self.project.clone(),
2123 self.action_log.clone(),
2124 model,
2125 window,
2126 cx,
2127 )
2128 };
2129
2130 // Store the card separately if it exists
2131 if let Some(card) = tool_result.card.clone() {
2132 self.tool_use
2133 .insert_tool_result_card(tool_use_id.clone(), card);
2134 }
2135
2136 cx.spawn({
2137 async move |thread: WeakEntity<Thread>, cx| {
2138 let output = tool_result.output.await;
2139
2140 thread
2141 .update(cx, |thread, cx| {
2142 let pending_tool_use = thread.tool_use.insert_tool_output(
2143 tool_use_id.clone(),
2144 tool_name,
2145 output,
2146 thread.configured_model.as_ref(),
2147 );
2148 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2149 })
2150 .ok();
2151 }
2152 })
2153 }
2154
2155 fn tool_finished(
2156 &mut self,
2157 tool_use_id: LanguageModelToolUseId,
2158 pending_tool_use: Option<PendingToolUse>,
2159 canceled: bool,
2160 window: Option<AnyWindowHandle>,
2161 cx: &mut Context<Self>,
2162 ) {
2163 if self.all_tools_finished() {
2164 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2165 if !canceled {
2166 self.send_to_model(model.clone(), window, cx);
2167 }
2168 self.auto_capture_telemetry(cx);
2169 }
2170 }
2171
2172 cx.emit(ThreadEvent::ToolFinished {
2173 tool_use_id,
2174 pending_tool_use,
2175 });
2176 }
2177
2178 /// Cancels the last pending completion, if there are any pending.
2179 ///
2180 /// Returns whether a completion was canceled.
2181 pub fn cancel_last_completion(
2182 &mut self,
2183 window: Option<AnyWindowHandle>,
2184 cx: &mut Context<Self>,
2185 ) -> bool {
2186 let mut canceled = self.pending_completions.pop().is_some();
2187
2188 for pending_tool_use in self.tool_use.cancel_pending() {
2189 canceled = true;
2190 self.tool_finished(
2191 pending_tool_use.id.clone(),
2192 Some(pending_tool_use),
2193 true,
2194 window,
2195 cx,
2196 );
2197 }
2198
2199 self.finalize_pending_checkpoint(cx);
2200
2201 if canceled {
2202 cx.emit(ThreadEvent::CompletionCanceled);
2203 }
2204
2205 canceled
2206 }
2207
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // Purely an event broadcast; no thread state is mutated here.
        cx.emit(ThreadEvent::CancelEditing);
    }
2215
    /// Thread-level feedback recorded via `report_feedback` (when no
    /// assistant message existed to attach it to), if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2219
    /// Feedback previously recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2223
    /// Records thumbs-up/down feedback for `message_id` and reports it (with
    /// a serialized copy of the thread and a project snapshot) via telemetry.
    /// Repeating identical feedback for the same message is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Unchanged feedback: nothing to record or report.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Kick off the (async) snapshot and serialization before mutating
        // feedback state; both complete in the background task below.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the telemetry event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2281
    /// Reports feedback for the thread as a whole.
    ///
    /// If the thread contains an assistant message, the feedback is attached to
    /// the most recent one via [`Self::report_message_feedback`]. Otherwise the
    /// feedback is stored thread-wide and a telemetry event (without
    /// message-level fields) is emitted directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: record thread-level feedback and emit a
            // telemetry event mirroring the message-level path.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Fall back to a null payload if serialization to JSON fails.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2327
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; unsaved (dirty) buffer
    /// paths are collected afterwards on the main thread. If the app context is
    /// gone by then, the unsaved-buffer list is simply left empty.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    // Only record buffers with unsaved edits that belong to a file.
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            // Best-effort: ignore failure if the app has shut down.
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2365
2366 fn worktree_snapshot(
2367 worktree: Entity<project::Worktree>,
2368 git_store: Entity<GitStore>,
2369 cx: &App,
2370 ) -> Task<WorktreeSnapshot> {
2371 cx.spawn(async move |cx| {
2372 // Get worktree path and snapshot
2373 let worktree_info = cx.update(|app_cx| {
2374 let worktree = worktree.read(app_cx);
2375 let path = worktree.abs_path().to_string_lossy().to_string();
2376 let snapshot = worktree.snapshot();
2377 (path, snapshot)
2378 });
2379
2380 let Ok((worktree_path, _snapshot)) = worktree_info else {
2381 return WorktreeSnapshot {
2382 worktree_path: String::new(),
2383 git_state: None,
2384 };
2385 };
2386
2387 let git_state = git_store
2388 .update(cx, |git_store, cx| {
2389 git_store
2390 .repositories()
2391 .values()
2392 .find(|repo| {
2393 repo.read(cx)
2394 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2395 .is_some()
2396 })
2397 .cloned()
2398 })
2399 .ok()
2400 .flatten()
2401 .map(|repo| {
2402 repo.update(cx, |repo, _| {
2403 let current_branch =
2404 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2405 repo.send_job(None, |state, _| async move {
2406 let RepositoryState::Local { backend, .. } = state else {
2407 return GitState {
2408 remote_url: None,
2409 head_sha: None,
2410 current_branch,
2411 diff: None,
2412 };
2413 };
2414
2415 let remote_url = backend.remote_url("origin");
2416 let head_sha = backend.head_sha().await;
2417 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2418
2419 GitState {
2420 remote_url,
2421 head_sha,
2422 current_branch,
2423 diff,
2424 }
2425 })
2426 })
2427 });
2428
2429 let git_state = match git_state {
2430 Some(git_state) => match git_state.ok() {
2431 Some(git_state) => git_state.await.ok(),
2432 None => None,
2433 },
2434 None => None,
2435 };
2436
2437 WorktreeSnapshot {
2438 worktree_path,
2439 git_state,
2440 }
2441 })
2442 }
2443
2444 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2445 let mut markdown = Vec::new();
2446
2447 let summary = self.summary().or_default();
2448 writeln!(markdown, "# {summary}\n")?;
2449
2450 for message in self.messages() {
2451 writeln!(
2452 markdown,
2453 "## {role}\n",
2454 role = match message.role {
2455 Role::User => "User",
2456 Role::Assistant => "Agent",
2457 Role::System => "System",
2458 }
2459 )?;
2460
2461 if !message.loaded_context.text.is_empty() {
2462 writeln!(markdown, "{}", message.loaded_context.text)?;
2463 }
2464
2465 if !message.loaded_context.images.is_empty() {
2466 writeln!(
2467 markdown,
2468 "\n{} images attached as context.\n",
2469 message.loaded_context.images.len()
2470 )?;
2471 }
2472
2473 for segment in &message.segments {
2474 match segment {
2475 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2476 MessageSegment::Thinking { text, .. } => {
2477 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2478 }
2479 MessageSegment::RedactedThinking(_) => {}
2480 }
2481 }
2482
2483 for tool_use in self.tool_uses_for_message(message.id, cx) {
2484 writeln!(
2485 markdown,
2486 "**Use Tool: {} ({})**",
2487 tool_use.name, tool_use.id
2488 )?;
2489 writeln!(markdown, "```json")?;
2490 writeln!(
2491 markdown,
2492 "{}",
2493 serde_json::to_string_pretty(&tool_use.input)?
2494 )?;
2495 writeln!(markdown, "```")?;
2496 }
2497
2498 for tool_result in self.tool_results_for_message(message.id) {
2499 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2500 if tool_result.is_error {
2501 write!(markdown, " (Error)")?;
2502 }
2503
2504 writeln!(markdown, "**\n")?;
2505 writeln!(markdown, "{}", tool_result.content)?;
2506 if let Some(output) = tool_result.output.as_ref() {
2507 writeln!(
2508 markdown,
2509 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2510 serde_json::to_string_pretty(output)?
2511 )?;
2512 }
2513 }
2514 }
2515
2516 Ok(String::from_utf8_lossy(&markdown).to_string())
2517 }
2518
    /// Marks agent edits within `buffer_range` of `buffer` as kept (accepted),
    /// delegating to the thread's [`ActionLog`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2529
    /// Marks all outstanding agent edits as kept (accepted) via the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2534
    /// Reverts agent edits within the given ranges of `buffer`, delegating to
    /// the action log. Returns the task performing the rejection.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2545
    /// Returns the action log tracking agent edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2549
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2553
2554 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2555 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2556 return;
2557 }
2558
2559 let now = Instant::now();
2560 if let Some(last) = self.last_auto_capture_at {
2561 if now.duration_since(last).as_secs() < 10 {
2562 return;
2563 }
2564 }
2565
2566 self.last_auto_capture_at = Some(now);
2567
2568 let thread_id = self.id().clone();
2569 let github_login = self
2570 .project
2571 .read(cx)
2572 .user_store()
2573 .read(cx)
2574 .current_user()
2575 .map(|user| user.github_login.clone());
2576 let client = self.project.read(cx).client().clone();
2577 let serialize_task = self.serialize(cx);
2578
2579 cx.background_executor()
2580 .spawn(async move {
2581 if let Ok(serialized_thread) = serialize_task.await {
2582 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2583 telemetry::event!(
2584 "Agent Thread Auto-Captured",
2585 thread_id = thread_id.to_string(),
2586 thread_data = thread_data,
2587 auto_capture_reason = "tracked_user",
2588 github_login = github_login
2589 );
2590
2591 client.telemetry().flush_events().await;
2592 }
2593 }
2594 })
2595 .detach();
2596 }
2597
    /// Returns the token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2601
2602 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2603 let Some(model) = self.configured_model.as_ref() else {
2604 return TotalTokenUsage::default();
2605 };
2606
2607 let max = model.model.max_token_count();
2608
2609 let index = self
2610 .messages
2611 .iter()
2612 .position(|msg| msg.id == message_id)
2613 .unwrap_or(0);
2614
2615 if index == 0 {
2616 return TotalTokenUsage { total: 0, max };
2617 }
2618
2619 let token_usage = &self
2620 .request_token_usage
2621 .get(index - 1)
2622 .cloned()
2623 .unwrap_or_default();
2624
2625 TotalTokenUsage {
2626 total: token_usage.total_tokens() as usize,
2627 max,
2628 }
2629 }
2630
2631 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2632 let model = self.configured_model.as_ref()?;
2633
2634 let max = model.model.max_token_count();
2635
2636 if let Some(exceeded_error) = &self.exceeded_window_error {
2637 if model.model.id() == exceeded_error.model_id {
2638 return Some(TotalTokenUsage {
2639 total: exceeded_error.token_count,
2640 max,
2641 });
2642 }
2643 }
2644
2645 let total = self
2646 .token_usage_at_last_message()
2647 .unwrap_or_default()
2648 .total_tokens() as usize;
2649
2650 Some(TotalTokenUsage { total, max })
2651 }
2652
    /// Returns the token usage recorded at the last message, if any.
    ///
    /// Prefers the entry aligned with the final message index; if the usage
    /// vector is shorter than the message list, falls back to its last entry.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2659
    /// Records `token_usage` as the usage for the last message.
    ///
    /// Grows (or shrinks) the per-message usage vector to match the message
    /// count, padding any gap with the previously recorded usage, then
    /// overwrites the final slot.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Padding value for any messages that never got a usage entry.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2669
2670 pub fn deny_tool_use(
2671 &mut self,
2672 tool_use_id: LanguageModelToolUseId,
2673 tool_name: Arc<str>,
2674 window: Option<AnyWindowHandle>,
2675 cx: &mut Context<Self>,
2676 ) {
2677 let err = Err(anyhow::anyhow!(
2678 "Permission to run tool action denied by user"
2679 ));
2680
2681 self.tool_use.insert_tool_output(
2682 tool_use_id.clone(),
2683 tool_name,
2684 err,
2685 self.configured_model.as_ref(),
2686 );
2687 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2688 }
2689}
2690
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further completions can run.
    #[error("Payment required")]
    PaymentRequired,
    /// The account's configured monthly spend cap has been hit.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model request quota has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a display header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2705
/// Events emitted by a [`Thread`] for UI and other listeners.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be displayed to the user.
    ShowError(ThreadError),
    /// A chunk of a streamed completion was processed.
    StreamedCompletion,
    /// A text chunk arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use is being streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use referenced by the model could not be found.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stream ended, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's git checkpoint state changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Any in-progress editing should be canceled.
    CancelEditing,
    /// A pending completion was canceled.
    CompletionCanceled,
}
2748
// Marker impl: lets `Thread` broadcast `ThreadEvent`s via `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
2750
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier assigned when the completion was started.
    id: usize,
    // Queueing status of this completion.
    queue_state: QueueState,
    // Held to keep the spawned completion work alive for the lifetime of
    // this entry.
    _task: Task<()>,
}
2756
2757#[cfg(test)]
2758mod tests {
2759 use super::*;
2760 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2761 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2762 use assistant_tool::ToolRegistry;
2763 use editor::EditorSettings;
2764 use gpui::TestAppContext;
2765 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2766 use project::{FakeFs, Project};
2767 use prompt_store::PromptBuilder;
2768 use serde_json::json;
2769 use settings::{Settings, SettingsStore};
2770 use std::sync::Arc;
2771 use theme::ThemeSettings;
2772 use util::path;
2773 use workspace::Workspace;
2774
    /// Verifies that a user message with an attached file context stores the
    /// rendered `<context>` block on the message and includes it, prepended to
    /// the message text, in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Index 0 is the system prompt; index 1 is our user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2850
    /// Verifies that each new user message only attaches contexts that earlier
    /// messages haven't already included, and that `new_context_for_thread`
    /// respects an "up to message" cutoff and context removals.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        // (plus the system prompt at index 0, hence 4 total)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, contexts attached at or after message 2
        // (file2, file3) plus the newly added file4 count as "new".
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3006
    /// Verifies that messages inserted without any context carry empty context
    /// text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Index 0 is the system prompt.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3084
    /// Verifies that when a context buffer is modified after being read, the
    /// next completion request ends with a "files changed since last read"
    /// notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3171
    /// Verifies that `model_parameters` temperature overrides match on
    /// provider and/or model: a matching provider+model, model-only, or
    /// provider-only rule applies the temperature; a rule whose provider
    /// differs from the configured model's provider does not.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider — the override must NOT apply.
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3265
    /// Exercises the thread summary lifecycle: Pending -> Generating (once a
    /// conversation exists) -> Ready; manual `set_summary` is rejected in the
    /// first two states and honored afterwards, with empty input falling back
    /// to the default summary.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then finish the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief".into());
        fake_model.stream_last_completion_response(" Introduction".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3345
3346 #[gpui::test]
3347 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3348 init_test_settings(cx);
3349
3350 let project = create_test_project(cx, json!({})).await;
3351
3352 let (_, _thread_store, thread, _context_store, model) =
3353 setup_test_environment(cx, project.clone()).await;
3354
3355 test_summarize_error(&model, &thread, cx);
3356
3357 // Now we should be able to set a summary
3358 thread.update(cx, |thread, cx| {
3359 thread.set_summary("Brief Intro", cx);
3360 });
3361
3362 thread.read_with(cx, |thread, _| {
3363 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3364 assert_eq!(thread.summary().or_default(), "Brief Intro");
3365 });
3366 }
3367
    // Verifies that after a failed summarization the thread stays in the Error
    // state across further messages, and that a manual `summarize` call retries
    // and can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Put the thread into the ThreadSummary::Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the retried request and close the
        // stream; the parked-executor runs let the summarize task make progress.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3418
    /// Drives `thread` through one successful message exchange and then a
    /// summarize request whose stream ends without producing any text, leaving
    /// the thread in the `ThreadSummary::Error` state. Shared by the
    /// summary-error tests above.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // The completed exchange kicks off summary generation.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending (stream closed with no chunks streamed)
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3448
    /// Completes the pending request on `fake_model` with a single assistant
    /// response chunk. The surrounding `run_until_parked` calls let the thread's
    /// streaming task observe the request and then the closed stream.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3455
    /// Registers the global settings store plus every settings/registry global
    /// these tests touch. The `SettingsStore` must be installed first, since the
    /// subsequent `init`/`register` calls read it.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3472
3473 // Helper to create a test project with test files
3474 async fn create_test_project(
3475 cx: &mut TestAppContext,
3476 files: serde_json::Value,
3477 ) -> Entity<Project> {
3478 let fs = FakeFs::new(cx.executor());
3479 fs.insert_tree(path!("/test"), files).await;
3480 Project::test(fs, [path!("/test").as_ref()], cx).await
3481 }
3482
3483 async fn setup_test_environment(
3484 cx: &mut TestAppContext,
3485 project: Entity<Project>,
3486 ) -> (
3487 Entity<Workspace>,
3488 Entity<ThreadStore>,
3489 Entity<Thread>,
3490 Entity<ContextStore>,
3491 Arc<dyn LanguageModel>,
3492 ) {
3493 let (workspace, cx) =
3494 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3495
3496 let thread_store = cx
3497 .update(|_, cx| {
3498 ThreadStore::load(
3499 project.clone(),
3500 cx.new(|_| ToolWorkingSet::default()),
3501 None,
3502 Arc::new(PromptBuilder::new(None).unwrap()),
3503 cx,
3504 )
3505 })
3506 .await
3507 .unwrap();
3508
3509 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3510 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3511
3512 let provider = Arc::new(FakeLanguageModelProvider);
3513 let model = provider.test_model();
3514 let model: Arc<dyn LanguageModel> = Arc::new(model);
3515
3516 cx.update(|_, cx| {
3517 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3518 registry.set_default_model(
3519 Some(ConfiguredModel {
3520 provider: provider.clone(),
3521 model: model.clone(),
3522 }),
3523 cx,
3524 );
3525 registry.set_thread_summary_model(
3526 Some(ConfiguredModel {
3527 provider,
3528 model: model.clone(),
3529 }),
3530 cx,
3531 );
3532 })
3533 });
3534
3535 (workspace, thread_store, thread, context_store, model)
3536 }
3537
3538 async fn add_file_to_context(
3539 project: &Entity<Project>,
3540 context_store: &Entity<ContextStore>,
3541 path: &str,
3542 cx: &mut TestAppContext,
3543 ) -> Result<Entity<language::Buffer>> {
3544 let buffer_path = project
3545 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3546 .unwrap();
3547
3548 let buffer = project
3549 .update(cx, |project, cx| {
3550 project.open_buffer(buffer_path.clone(), cx)
3551 })
3552 .await
3553 .unwrap();
3554
3555 context_store.update(cx, |context_store, cx| {
3556 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3557 });
3558
3559 Ok(buffer)
3560 }
3561}