1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MaxMonthlySpendReachedError,
26 MessageContent, ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role,
27 SelectedModel, StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Range within the message text covered by the crease.
    pub range: Range<usize>,
    // Icon/label used to render the crease placeholder.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered content chunks (text / thinking / redacted thinking).
    pub segments: Vec<MessageSegment>,
    // Context that was attached when the message was created.
    pub loaded_context: LoadedContext,
    // Creases for re-creating context chips when editing this message.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// A homogeneous chunk of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary visible text.
    Text(String),
    /// Chain-of-thought text, optionally accompanied by a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-encrypted thinking data; never rendered.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries meaningful, renderable content.
    ///
    /// Fixed: text/thinking previously returned `text.is_empty()`, which
    /// inverted the check — only *empty* segments counted as displayable,
    /// contradicting [`Message::should_display_content`]'s contract.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking has no renderable content by definition.
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Point-in-time capture of the project's worktrees, taken when a thread starts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // NOTE(review): presumably `None` when the worktree has no git repository — confirm.
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    // Diff text, when one could be produced.
    pub diff: Option<String>,
}
216
/// Associates a message with a git checkpoint of the project state.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up / thumbs-down feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
229pub enum LastRestoreCheckpoint {
230 Pending {
231 message_id: MessageId,
232 },
233 Error {
234 message_id: MessageId,
235 error: String,
236 },
237}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
249pub enum DetailedSummaryState {
250 #[default]
251 NotGenerated,
252 Generating {
253 message_id: MessageId,
254 },
255 Generated {
256 text: SharedString,
257 message_id: MessageId,
258 },
259}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token total for a thread, measured against the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size; 0 when no model is selected (max unknown).
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal / warning / exceeded.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be tuned via env var.
        // Fixed: a missing value previously allocated a fallback string and a
        // *malformed* value panicked via `.parse().unwrap()`; now any absent
        // or unparsable value falls back to the default threshold.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more consumed.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close a thread is to its model's context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// Where a pending completion stands before/while streaming begins.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request is being sent; no queue information yet.
    Sending,
    /// Request is waiting in the provider's queue at `position`.
    Queued { position: usize },
    /// The completion has started streaming.
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // Kept sorted by ID (IDs are assigned monotonically and messages are
    // appended), which `Self::message` relies on for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured when the latest user message was inserted; kept or
    // discarded by `finalize_pending_checkpoint` once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
364#[derive(Clone, Debug, PartialEq, Eq)]
365pub enum ThreadSummary {
366 Pending,
367 Generating,
368 Ready(SharedString),
369 Error,
370}
371
372impl ThreadSummary {
373 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
374
375 pub fn or_default(&self) -> SharedString {
376 self.unwrap_or(Self::DEFAULT)
377 }
378
379 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
380 self.ready().unwrap_or_else(|| message.into())
381 }
382
383 pub fn ready(&self) -> Option<SharedString> {
384 match self {
385 ThreadSummary::Ready(summary) => Some(summary.clone()),
386 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
387 }
388 }
389}
390
/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
398
399impl Thread {
    /// Creates an empty thread, seeding the completion mode from settings and
    /// the configured model from the global registry's default.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project state once, up front, and share the task
                // so every later consumer awaits the same snapshot.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
453
    /// Reconstructs a thread from its serialized on-disk form.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // IDs are monotonic, so the next ID follows the last serialized one.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded in the serialized thread; fall back to
        // the registry default when it is missing or no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // live context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    // Deserialized creases carry no live context (see
                    // `MessageCrease::context`).
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
573
    /// Installs a callback that observes every completion request and the
    /// events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified right now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, initializing it from the registry's
    /// default the first time it is needed.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
625
    /// Updates the summary text, ignoring the call while a summary is
    /// pending/generating and substituting the default for empty input.
    /// Emits [`ThreadEvent::SummaryChanged`] only when the text changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // An errored summary can always be overwritten.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
652
    /// Looks up a message by ID.
    ///
    /// Uses binary search, relying on `messages` being sorted by ID (IDs are
    /// assigned monotonically and messages are only appended).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    // NOTE(review): the code only checks `last_received_chunk_at`, so the
    // `None`-when-not-generating claim holds only if that field is cleared
    // when generation ends — confirm against the rest of the file.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given ID, if it is still pending.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
710
    /// Returns the checkpoint recorded for message `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    /// Restores the project to `checkpoint`'s git state, truncating the
    /// thread back to the checkpointed message on success.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Surface the pending state immediately so the UI can reflect it.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so the UI can show why the restore failed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
750
    /// Commits or discards the checkpoint taken at the start of the current
    /// turn. Does nothing while generation is still in progress.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Only keep the checkpoint when the repository actually
                // changed over the course of the turn.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If taking a final checkpoint fails, keep the pending one rather
            // than silently losing it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    /// Records `checkpoint` against its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
800
    /// Removes message `message_id` and every message after it, dropping any
    /// checkpoints associated with the removed messages.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Iterates over the agent context attached to message `id` (empty when
    /// the message does not exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
822
823 pub fn is_turn_end(&self, ix: usize) -> bool {
824 if self.messages.is_empty() {
825 return false;
826 }
827
828 if !self.is_generating() && ix == self.messages.len() - 1 {
829 return true;
830 }
831
832 let Some(message) = self.messages.get(ix) else {
833 return false;
834 };
835
836 if message.role != Role::Assistant {
837 return false;
838 }
839
840 self.messages
841 .get(ix + 1)
842 .and_then(|message| {
843 self.message(message.id)
844 .map(|next_message| next_message.role == Role::User)
845 })
846 .unwrap_or(false)
847 }
848
    /// Usage reported for the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the tool-use limit has been reached for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses issued by the assistant message `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced for the given assistant message's tool uses.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a tool use; image results currently yield `None`.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(str) => Some(str),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// UI card rendered for a tool use, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
895
896 /// Return tools that are both enabled and supported by the model
897 pub fn available_tools(
898 &self,
899 cx: &App,
900 model: Arc<dyn LanguageModel>,
901 ) -> Vec<LanguageModelRequestTool> {
902 if model.supports_tools() {
903 self.tools()
904 .read(cx)
905 .enabled_tools(cx)
906 .into_iter()
907 .filter_map(|tool| {
908 // Skip tools that cannot be supported
909 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
910 Some(LanguageModelRequestTool {
911 name: tool.name(),
912 description: tool.description(),
913 input_schema,
914 })
915 })
916 .collect()
917 } else {
918 Vec::default()
919 }
920 }
921
    /// Inserts a user message, logging context-referenced buffers as read and
    /// stashing `git_checkpoint` as this turn's pending checkpoint.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Record which buffers the context referenced so the action log knows
        // what the model has seen.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Inserts an assistant message consisting of `segments`, with no
    /// attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
971
    /// Appends a message with the next ID, bumps `updated_at`, and emits
    /// [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }

    /// Replaces the role and segments (and, when provided, context) of
    /// message `id`. Returns `false` when the message does not exist.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes message `id`. Returns `false` when the message does not exist.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
1023
1024 /// Returns the representation of this [`Thread`] in a textual form.
1025 ///
1026 /// This is the representation we use when attaching a thread as context to another thread.
1027 pub fn text(&self) -> String {
1028 let mut text = String::new();
1029
1030 for message in &self.messages {
1031 text.push_str(match message.role {
1032 language_model::Role::User => "User:",
1033 language_model::Role::Assistant => "Agent:",
1034 language_model::Role::System => "System:",
1035 });
1036 text.push('\n');
1037
1038 for segment in &message.segments {
1039 match segment {
1040 MessageSegment::Text(content) => text.push_str(content),
1041 MessageSegment::Thinking { text: content, .. } => {
1042 text.push_str(&format!("<think>{}</think>", content))
1043 }
1044 MessageSegment::RedactedThinking(_) => {}
1045 }
1046 }
1047 text.push('\n');
1048 }
1049
1050 text
1051 }
1052
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Await the shared initial snapshot first so it can be embedded in
        // the serialized form.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1135
    /// Returns how many model turns this thread is still allowed to take
    /// before `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1139
    /// Sets the number of model turns this thread may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1143
1144 pub fn send_to_model(
1145 &mut self,
1146 model: Arc<dyn LanguageModel>,
1147 window: Option<AnyWindowHandle>,
1148 cx: &mut Context<Self>,
1149 ) {
1150 if self.remaining_turns == 0 {
1151 return;
1152 }
1153
1154 self.remaining_turns -= 1;
1155
1156 let request = self.to_completion_request(model.clone(), cx);
1157
1158 self.stream_completion(request, model, window, cx);
1159 }
1160
1161 pub fn used_tools_since_last_user_message(&self) -> bool {
1162 for message in self.messages.iter().rev() {
1163 if self.tool_use.message_has_tool_results(message.id) {
1164 return true;
1165 } else if message.role == Role::User {
1166 return false;
1167 }
1168 }
1169
1170 false
1171 }
1172
    /// Builds the `LanguageModelRequest` for the next completion from the
    /// current thread state: system prompt, conversation history, tool uses
    /// and their results, a stale-file notice, and the available tools.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt is generated from the shared project context. If
        // that context isn't ready yet, proceed without a system message and
        // surface an error to the UI.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible for prompt caching; only
        // that single message gets `cache = true` after the loop below.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Context (e.g. attached files) is prepended before the message's
            // own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool results are sent back in a separate User message
            // that follows the assistant message containing the tool uses.
            // Messages with still-pending tool uses are not cacheable.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        // Append a notice about files that changed since the model last read
        // them, if any.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1330
1331 fn to_summarize_request(
1332 &self,
1333 model: &Arc<dyn LanguageModel>,
1334 added_user_message: String,
1335 cx: &App,
1336 ) -> LanguageModelRequest {
1337 let mut request = LanguageModelRequest {
1338 thread_id: None,
1339 prompt_id: None,
1340 mode: None,
1341 messages: vec![],
1342 tools: Vec::new(),
1343 tool_choice: None,
1344 stop: Vec::new(),
1345 temperature: AssistantSettings::temperature_for_model(model, cx),
1346 };
1347
1348 for message in &self.messages {
1349 let mut request_message = LanguageModelRequestMessage {
1350 role: message.role,
1351 content: Vec::new(),
1352 cache: false,
1353 };
1354
1355 for segment in &message.segments {
1356 match segment {
1357 MessageSegment::Text(text) => request_message
1358 .content
1359 .push(MessageContent::Text(text.clone())),
1360 MessageSegment::Thinking { .. } => {}
1361 MessageSegment::RedactedThinking(_) => {}
1362 }
1363 }
1364
1365 if request_message.content.is_empty() {
1366 continue;
1367 }
1368
1369 request.messages.push(request_message);
1370 }
1371
1372 request.messages.push(LanguageModelRequestMessage {
1373 role: Role::User,
1374 content: vec![MessageContent::Text(added_user_message)],
1375 cache: false,
1376 });
1377
1378 request
1379 }
1380
1381 fn attached_tracked_files_state(
1382 &self,
1383 messages: &mut Vec<LanguageModelRequestMessage>,
1384 cx: &App,
1385 ) {
1386 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1387
1388 let mut stale_message = String::new();
1389
1390 let action_log = self.action_log.read(cx);
1391
1392 for stale_file in action_log.stale_buffers(cx) {
1393 let Some(file) = stale_file.read(cx).file() else {
1394 continue;
1395 };
1396
1397 if stale_message.is_empty() {
1398 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1399 }
1400
1401 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1402 }
1403
1404 let mut content = Vec::with_capacity(2);
1405
1406 if !stale_message.is_empty() {
1407 content.push(stale_message.into());
1408 }
1409
1410 if !content.is_empty() {
1411 let context_message = LanguageModelRequestMessage {
1412 role: Role::User,
1413 content,
1414 cache: false,
1415 };
1416
1417 messages.push(context_message);
1418 }
1419 }
1420
    /// Streams a completion for `request` from `model`, applying stream events
    /// to the thread as they arrive (text/thinking chunks, tool uses, token
    /// usage, queue status) and registering the work as a pending completion
    /// so it can be canceled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        // Each new request starts with a fresh tool-use-limit flag.
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record request/response data when a debug callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            // Drives the event stream to completion; resolves to the final
            // stop reason (or an error).
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Identifies the assistant message that incoming chunks are
                // appended to, once one has been created.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                // Malformed tool input is reported back as a
                                // tool error rather than failing the stream.
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative for the current
                                // request, so only the delta since the last
                                // update is folded into the thread-wide total.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially-streamed inputs are forwarded
                                // to the UI below.
                                // NOTE(review): `(&tool_use.input).clone()` is
                                // equivalent to `tool_use.input.clone()`.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                // Status updates only apply while this request
                                // is still in `pending_completions`.
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other tasks (e.g. UI updates) get a chance to
                    // run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to dedicated UI events;
                            // anything else is shown as a generic message
                            // containing the full error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report per-request token usage as the difference from
                    // the cumulative count captured before streaming began.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1767
1768 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1769 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1770 println!("No thread summary model");
1771 return;
1772 };
1773
1774 if !model.provider.is_authenticated(cx) {
1775 return;
1776 }
1777
1778 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1779 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1780 If the conversation is about a specific subject, include it in the title. \
1781 Be descriptive. DO NOT speak in the first person.";
1782
1783 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1784
1785 self.summary = ThreadSummary::Generating;
1786
1787 self.pending_summary = cx.spawn(async move |this, cx| {
1788 let result = async {
1789 let mut messages = model.model.stream_completion(request, &cx).await?;
1790
1791 let mut new_summary = String::new();
1792 while let Some(event) = messages.next().await {
1793 let Ok(event) = event else {
1794 continue;
1795 };
1796 let text = match event {
1797 LanguageModelCompletionEvent::Text(text) => text,
1798 LanguageModelCompletionEvent::StatusUpdate(
1799 CompletionRequestStatus::UsageUpdated { amount, limit },
1800 ) => {
1801 this.update(cx, |thread, _cx| {
1802 thread.last_usage = Some(RequestUsage {
1803 limit,
1804 amount: amount as i32,
1805 });
1806 })?;
1807 continue;
1808 }
1809 _ => continue,
1810 };
1811
1812 let mut lines = text.lines();
1813 new_summary.extend(lines.next());
1814
1815 // Stop if the LLM generated multiple lines.
1816 if lines.next().is_some() {
1817 break;
1818 }
1819 }
1820
1821 anyhow::Ok(new_summary)
1822 }
1823 .await;
1824
1825 this.update(cx, |this, cx| {
1826 match result {
1827 Ok(new_summary) => {
1828 if new_summary.is_empty() {
1829 this.summary = ThreadSummary::Error;
1830 } else {
1831 this.summary = ThreadSummary::Ready(new_summary.into());
1832 }
1833 }
1834 Err(err) => {
1835 this.summary = ThreadSummary::Error;
1836 log::error!("Failed to generate thread summary: {}", err);
1837 }
1838 }
1839 cx.emit(ThreadEvent::SummaryGenerated);
1840 })
1841 .log_err()?;
1842
1843 Some(())
1844 });
1845 }
1846
    /// Starts generating a detailed Markdown summary of the thread unless an
    /// existing (or in-flight) summary already covers the latest message.
    ///
    /// Progress and the final text are published through
    /// `detailed_summary_tx`; the thread is saved afterwards so the summary
    /// can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed to start; reset state so a later call can
                // retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1936
1937 pub async fn wait_for_detailed_summary_or_text(
1938 this: &Entity<Self>,
1939 cx: &mut AsyncApp,
1940 ) -> Option<SharedString> {
1941 let mut detailed_summary_rx = this
1942 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1943 .ok()?;
1944 loop {
1945 match detailed_summary_rx.recv().await? {
1946 DetailedSummaryState::Generating { .. } => {}
1947 DetailedSummaryState::NotGenerated => {
1948 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1949 }
1950 DetailedSummaryState::Generated { text, .. } => return Some(text),
1951 }
1952 }
1953 }
1954
1955 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1956 self.detailed_summary_rx
1957 .borrow()
1958 .text()
1959 .unwrap_or_else(|| self.text().into())
1960 }
1961
    /// Whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1968
    /// Dispatches every idle pending tool use: tools needing confirmation are
    /// queued for user approval, runnable ones are executed, and tool names
    /// the model hallucinated are answered with an error listing the real
    /// tools.
    ///
    /// Returns the pending tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // The request state is captured once and shared by every tool run.
        let request = Arc::new(self.to_completion_request(model.clone(), cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                // Hold the tool for user approval unless confirmation is not
                // required or the always-allow setting is enabled.
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model requested a tool that doesn't exist or is
                // disabled.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2022
2023 pub fn handle_hallucinated_tool_use(
2024 &mut self,
2025 tool_use_id: LanguageModelToolUseId,
2026 hallucinated_tool_name: Arc<str>,
2027 window: Option<AnyWindowHandle>,
2028 cx: &mut Context<Thread>,
2029 ) {
2030 let available_tools = self.tools.read(cx).enabled_tools(cx);
2031
2032 let tool_list = available_tools
2033 .iter()
2034 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2035 .collect::<Vec<_>>()
2036 .join("\n");
2037
2038 let error_message = format!(
2039 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2040 hallucinated_tool_name, tool_list
2041 );
2042
2043 let pending_tool_use = self.tool_use.insert_tool_output(
2044 tool_use_id.clone(),
2045 hallucinated_tool_name,
2046 Err(anyhow!("Missing tool call: {error_message}")),
2047 self.configured_model.as_ref(),
2048 );
2049
2050 cx.emit(ThreadEvent::MissingToolUse {
2051 tool_use_id: tool_use_id.clone(),
2052 ui_text: error_message.into(),
2053 });
2054
2055 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2056 }
2057
2058 pub fn receive_invalid_tool_json(
2059 &mut self,
2060 tool_use_id: LanguageModelToolUseId,
2061 tool_name: Arc<str>,
2062 invalid_json: Arc<str>,
2063 error: String,
2064 window: Option<AnyWindowHandle>,
2065 cx: &mut Context<Thread>,
2066 ) {
2067 log::error!("The model returned invalid input JSON: {invalid_json}");
2068
2069 let pending_tool_use = self.tool_use.insert_tool_output(
2070 tool_use_id.clone(),
2071 tool_name,
2072 Err(anyhow!("Error parsing input JSON: {error}")),
2073 self.configured_model.as_ref(),
2074 );
2075 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2076 pending_tool_use.ui_text.clone()
2077 } else {
2078 log::error!(
2079 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2080 );
2081 format!("Unknown tool {}", tool_use_id).into()
2082 };
2083
2084 cx.emit(ThreadEvent::InvalidToolInput {
2085 tool_use_id: tool_use_id.clone(),
2086 ui_text,
2087 invalid_input_json: invalid_json,
2088 });
2089
2090 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2091 }
2092
2093 pub fn run_tool(
2094 &mut self,
2095 tool_use_id: LanguageModelToolUseId,
2096 ui_text: impl Into<SharedString>,
2097 input: serde_json::Value,
2098 request: Arc<LanguageModelRequest>,
2099 tool: Arc<dyn Tool>,
2100 model: Arc<dyn LanguageModel>,
2101 window: Option<AnyWindowHandle>,
2102 cx: &mut Context<Thread>,
2103 ) {
2104 let task =
2105 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2106 self.tool_use
2107 .run_pending_tool(tool_use_id, ui_text.into(), task);
2108 }
2109
    /// Spawns the asynchronous execution of a single tool use and records its
    /// output via `tool_finished` when it completes.
    ///
    /// Disabled tools resolve immediately to an error result instead of
    /// running.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        // Record completion and possibly resume the turn.
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2160
    /// Records that a tool use finished. When *all* pending tools have
    /// finished and a model is configured, the results are sent back to the
    /// model to continue the turn (skipped when the tool run was canceled).
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
                if !canceled {
                    self.send_to_model(model.clone(), window, cx);
                }
                self.auto_capture_telemetry(cx);
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
2183
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Popping the completion also drops its `_task`.
        let mut canceled = self.pending_completions.pop().is_some();

        // Any in-flight tool uses are canceled and finalized as such.
        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        self.finalize_pending_checkpoint(cx);

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);
        }

        canceled
    }
2213
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It does not mutate any
    /// thread state itself; it only emits `ThreadEvent::CancelEditing`.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2221
    /// Returns the thread-level feedback rating, if one has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2225
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2229
    /// Records the user's feedback rating for a message and reports it (along
    /// with a serialized copy of the thread and a project snapshot) via
    /// telemetry.
    ///
    /// Re-submitting the same rating for the same message is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything the background task needs before moving off the
        // main context.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than
            // dropping the telemetry event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2287
2288 pub fn report_feedback(
2289 &mut self,
2290 feedback: ThreadFeedback,
2291 cx: &mut Context<Self>,
2292 ) -> Task<Result<()>> {
2293 let last_assistant_message_id = self
2294 .messages
2295 .iter()
2296 .rev()
2297 .find(|msg| msg.role == Role::Assistant)
2298 .map(|msg| msg.id);
2299
2300 if let Some(message_id) = last_assistant_message_id {
2301 self.report_message_feedback(message_id, feedback, cx)
2302 } else {
2303 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2304 let serialized_thread = self.serialize(cx);
2305 let thread_id = self.id().clone();
2306 let client = self.project.read(cx).client();
2307 self.feedback = Some(feedback);
2308 cx.notify();
2309
2310 cx.background_spawn(async move {
2311 let final_project_snapshot = final_project_snapshot.await;
2312 let serialized_thread = serialized_thread.await?;
2313 let thread_data = serde_json::to_value(serialized_thread)
2314 .unwrap_or_else(|_| serde_json::Value::Null);
2315
2316 let rating = match feedback {
2317 ThreadFeedback::Positive => "positive",
2318 ThreadFeedback::Negative => "negative",
2319 };
2320 telemetry::event!(
2321 "Assistant Thread Rated",
2322 rating,
2323 thread_id,
2324 thread_data,
2325 final_project_snapshot
2326 );
2327 client.telemetry().flush_events().await;
2328
2329 Ok(())
2330 })
2331 }
2332 }
2333
2334 /// Create a snapshot of the current project state including git information and unsaved buffers.
2335 fn project_snapshot(
2336 project: Entity<Project>,
2337 cx: &mut Context<Self>,
2338 ) -> Task<Arc<ProjectSnapshot>> {
2339 let git_store = project.read(cx).git_store().clone();
2340 let worktree_snapshots: Vec<_> = project
2341 .read(cx)
2342 .visible_worktrees(cx)
2343 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2344 .collect();
2345
2346 cx.spawn(async move |_, cx| {
2347 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2348
2349 let mut unsaved_buffers = Vec::new();
2350 cx.update(|app_cx| {
2351 let buffer_store = project.read(app_cx).buffer_store();
2352 for buffer_handle in buffer_store.read(app_cx).buffers() {
2353 let buffer = buffer_handle.read(app_cx);
2354 if buffer.is_dirty() {
2355 if let Some(file) = buffer.file() {
2356 let path = file.path().to_string_lossy().to_string();
2357 unsaved_buffers.push(path);
2358 }
2359 }
2360 }
2361 })
2362 .ok();
2363
2364 Arc::new(ProjectSnapshot {
2365 worktree_snapshots,
2366 unsaved_buffer_paths: unsaved_buffers,
2367 timestamp: Utc::now(),
2368 })
2369 })
2370 }
2371
    /// Captures a telemetry-friendly snapshot of a single worktree: its
    /// absolute path plus git metadata (remote, HEAD, branch, diff) when the
    /// worktree belongs to a local repository.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is unavailable (e.g. during shutdown),
            // return an empty snapshot rather than failing the whole capture.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository (if any) whose root contains this worktree.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        // Git metadata is read on the repository's job queue;
                        // only local repositories expose a backend, so remote
                        // ones report branch name only.
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested fallible layers (entity update, job handle,
            // job result), treating any failure as "no git state".
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2449
2450 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2451 let mut markdown = Vec::new();
2452
2453 let summary = self.summary().or_default();
2454 writeln!(markdown, "# {summary}\n")?;
2455
2456 for message in self.messages() {
2457 writeln!(
2458 markdown,
2459 "## {role}\n",
2460 role = match message.role {
2461 Role::User => "User",
2462 Role::Assistant => "Agent",
2463 Role::System => "System",
2464 }
2465 )?;
2466
2467 if !message.loaded_context.text.is_empty() {
2468 writeln!(markdown, "{}", message.loaded_context.text)?;
2469 }
2470
2471 if !message.loaded_context.images.is_empty() {
2472 writeln!(
2473 markdown,
2474 "\n{} images attached as context.\n",
2475 message.loaded_context.images.len()
2476 )?;
2477 }
2478
2479 for segment in &message.segments {
2480 match segment {
2481 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2482 MessageSegment::Thinking { text, .. } => {
2483 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2484 }
2485 MessageSegment::RedactedThinking(_) => {}
2486 }
2487 }
2488
2489 for tool_use in self.tool_uses_for_message(message.id, cx) {
2490 writeln!(
2491 markdown,
2492 "**Use Tool: {} ({})**",
2493 tool_use.name, tool_use.id
2494 )?;
2495 writeln!(markdown, "```json")?;
2496 writeln!(
2497 markdown,
2498 "{}",
2499 serde_json::to_string_pretty(&tool_use.input)?
2500 )?;
2501 writeln!(markdown, "```")?;
2502 }
2503
2504 for tool_result in self.tool_results_for_message(message.id) {
2505 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2506 if tool_result.is_error {
2507 write!(markdown, " (Error)")?;
2508 }
2509
2510 writeln!(markdown, "**\n")?;
2511 match &tool_result.content {
2512 LanguageModelToolResultContent::Text(str) => {
2513 writeln!(markdown, "{}", str)?;
2514 }
2515 LanguageModelToolResultContent::Image(image) => {
2516 writeln!(markdown, "", image.source)?;
2517 }
2518 }
2519
2520 if let Some(output) = tool_result.output.as_ref() {
2521 writeln!(
2522 markdown,
2523 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2524 serde_json::to_string_pretty(output)?
2525 )?;
2526 }
2527 }
2528 }
2529
2530 Ok(String::from_utf8_lossy(&markdown).to_string())
2531 }
2532
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept,
    /// delegating to this thread's [`ActionLog`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2543
    /// Marks all of the agent's pending edits as kept via the [`ActionLog`].
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2548
    /// Rejects (reverts) the agent's edits within the given ranges of
    /// `buffer`, returning the [`ActionLog`] task performing the revert.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2559
    /// Returns the [`ActionLog`] tracking this thread's buffer edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2563
    /// Returns the [`Project`] this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2567
2568 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2569 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2570 return;
2571 }
2572
2573 let now = Instant::now();
2574 if let Some(last) = self.last_auto_capture_at {
2575 if now.duration_since(last).as_secs() < 10 {
2576 return;
2577 }
2578 }
2579
2580 self.last_auto_capture_at = Some(now);
2581
2582 let thread_id = self.id().clone();
2583 let github_login = self
2584 .project
2585 .read(cx)
2586 .user_store()
2587 .read(cx)
2588 .current_user()
2589 .map(|user| user.github_login.clone());
2590 let client = self.project.read(cx).client().clone();
2591 let serialize_task = self.serialize(cx);
2592
2593 cx.background_executor()
2594 .spawn(async move {
2595 if let Ok(serialized_thread) = serialize_task.await {
2596 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2597 telemetry::event!(
2598 "Agent Thread Auto-Captured",
2599 thread_id = thread_id.to_string(),
2600 thread_data = thread_data,
2601 auto_capture_reason = "tracked_user",
2602 github_login = github_login
2603 );
2604
2605 client.telemetry().flush_events().await;
2606 }
2607 }
2608 })
2609 .detach();
2610 }
2611
    /// Returns the token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2615
2616 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2617 let Some(model) = self.configured_model.as_ref() else {
2618 return TotalTokenUsage::default();
2619 };
2620
2621 let max = model.model.max_token_count();
2622
2623 let index = self
2624 .messages
2625 .iter()
2626 .position(|msg| msg.id == message_id)
2627 .unwrap_or(0);
2628
2629 if index == 0 {
2630 return TotalTokenUsage { total: 0, max };
2631 }
2632
2633 let token_usage = &self
2634 .request_token_usage
2635 .get(index - 1)
2636 .cloned()
2637 .unwrap_or_default();
2638
2639 TotalTokenUsage {
2640 total: token_usage.total_tokens() as usize,
2641 max,
2642 }
2643 }
2644
2645 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2646 let model = self.configured_model.as_ref()?;
2647
2648 let max = model.model.max_token_count();
2649
2650 if let Some(exceeded_error) = &self.exceeded_window_error {
2651 if model.model.id() == exceeded_error.model_id {
2652 return Some(TotalTokenUsage {
2653 total: exceeded_error.token_count,
2654 max,
2655 });
2656 }
2657 }
2658
2659 let total = self
2660 .token_usage_at_last_message()
2661 .unwrap_or_default()
2662 .total_tokens() as usize;
2663
2664 Some(TotalTokenUsage { total, max })
2665 }
2666
2667 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2668 self.request_token_usage
2669 .get(self.messages.len().saturating_sub(1))
2670 .or_else(|| self.request_token_usage.last())
2671 .cloned()
2672 }
2673
2674 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2675 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2676 self.request_token_usage
2677 .resize(self.messages.len(), placeholder);
2678
2679 if let Some(last) = self.request_token_usage.last_mut() {
2680 *last = token_usage;
2681 }
2682 }
2683
2684 pub fn deny_tool_use(
2685 &mut self,
2686 tool_use_id: LanguageModelToolUseId,
2687 tool_name: Arc<str>,
2688 window: Option<AnyWindowHandle>,
2689 cx: &mut Context<Self>,
2690 ) {
2691 let err = Err(anyhow::anyhow!(
2692 "Permission to run tool action denied by user"
2693 ));
2694
2695 self.tool_use.insert_tool_output(
2696 tool_use_id.clone(),
2697 tool_name,
2698 err,
2699 self.configured_model.as_ref(),
2700 );
2701 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2702 }
2703}
2704
/// User-facing errors surfaced while running a thread's completion.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before more requests can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model request limit for the given plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and detail message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2719
/// Events emitted by a `Thread` as completions stream, tools run, and
/// messages change.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be presented to the user.
    ShowError(ThreadError),
    /// A chunk of the completion was received.
    StreamedCompletion,
    /// A chunk of text was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text belonging to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text belonging to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use was streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use was expected but missing.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be parsed as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's git checkpoint changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Listeners should cancel any in-progress editing.
    CancelEditing,
    /// A pending completion was canceled.
    CompletionCanceled,
}
2762
// Lets `Thread` broadcast `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2764
/// A completion request currently in flight for a thread.
struct PendingCompletion {
    // Identifier distinguishing this completion from others (see
    // `cancel_last_completion`, which pops the most recent one).
    id: usize,
    queue_state: QueueState,
    // Held so the streaming task keeps running; popping the completion drops
    // the task, which `cancel_last_completion` counts as a cancellation.
    _task: Task<()>,
}
2770
2771#[cfg(test)]
2772mod tests {
2773 use super::*;
2774 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2775 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2776 use assistant_tool::ToolRegistry;
2777 use editor::EditorSettings;
2778 use gpui::TestAppContext;
2779 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2780 use project::{FakeFs, Project};
2781 use prompt_store::PromptBuilder;
2782 use serde_json::json;
2783 use settings::{Settings, SettingsStore};
2784 use std::sync::Arc;
2785 use theme::ThemeSettings;
2786 use util::path;
2787 use workspace::Workspace;
2788
    /// Verifies that file context attached to a user message is rendered into
    /// the message's loaded context and included in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2864
    /// Verifies that each message only carries context that has not already
    /// been attached to an earlier message, and that removing context updates
    /// what counts as "new" relative to a given message.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Relative to message 2, contexts 2-4 count as new (1 was attached earlier).
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3020
    /// Verifies that messages without any attached context have empty loaded
    /// context and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3098
    /// Verifies that editing a buffer previously attached as context causes
    /// the completion request to end with a "files changed since last read"
    /// notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3185
    /// Verifies that per-model temperature settings match on provider, model,
    /// both, or neither — and do not apply when the provider differs.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3279
    /// Walks the summary state machine: Pending → Generating → Ready. Manual
    /// summaries are rejected until generation completes, and an empty manual
    /// summary falls back to the default.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief".into());
        fake_model.stream_last_completion_response(" Introduction".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3359
3360 #[gpui::test]
3361 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3362 init_test_settings(cx);
3363
3364 let project = create_test_project(cx, json!({})).await;
3365
3366 let (_, _thread_store, thread, _context_store, model) =
3367 setup_test_environment(cx, project.clone()).await;
3368
3369 test_summarize_error(&model, &thread, cx);
3370
3371 // Now we should be able to set a summary
3372 thread.update(cx, |thread, cx| {
3373 thread.set_summary("Brief Intro", cx);
3374 });
3375
3376 thread.read_with(cx, |thread, _| {
3377 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3378 assert_eq!(thread.summary().or_default(), "Brief Intro");
3379 });
3380 }
3381
3382 #[gpui::test]
3383 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3384 init_test_settings(cx);
3385
3386 let project = create_test_project(cx, json!({})).await;
3387
3388 let (_, _thread_store, thread, _context_store, model) =
3389 setup_test_environment(cx, project.clone()).await;
3390
3391 test_summarize_error(&model, &thread, cx);
3392
3393 // Sending another message should not trigger another summarize request
3394 thread.update(cx, |thread, cx| {
3395 thread.insert_user_message(
3396 "How are you?",
3397 ContextLoadResult::default(),
3398 None,
3399 vec![],
3400 cx,
3401 );
3402 thread.send_to_model(model.clone(), None, cx);
3403 });
3404
3405 let fake_model = model.as_fake();
3406 simulate_successful_response(&fake_model, cx);
3407
3408 thread.read_with(cx, |thread, _| {
3409 // State is still Error, not Generating
3410 assert!(matches!(thread.summary(), ThreadSummary::Error));
3411 });
3412
3413 // But the summarize request can be invoked manually
3414 thread.update(cx, |thread, cx| {
3415 thread.summarize(cx);
3416 });
3417
3418 thread.read_with(cx, |thread, _| {
3419 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3420 });
3421
3422 cx.run_until_parked();
3423 fake_model.stream_last_completion_response("A successful summary".into());
3424 fake_model.end_last_completion_stream();
3425 cx.run_until_parked();
3426
3427 thread.read_with(cx, |thread, _| {
3428 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3429 assert_eq!(thread.summary().or_default(), "A successful summary");
3430 });
3431 }
3432
3433 fn test_summarize_error(
3434 model: &Arc<dyn LanguageModel>,
3435 thread: &Entity<Thread>,
3436 cx: &mut TestAppContext,
3437 ) {
3438 thread.update(cx, |thread, cx| {
3439 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3440 thread.send_to_model(model.clone(), None, cx);
3441 });
3442
3443 let fake_model = model.as_fake();
3444 simulate_successful_response(&fake_model, cx);
3445
3446 thread.read_with(cx, |thread, _| {
3447 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3448 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3449 });
3450
3451 // Simulate summary request ending
3452 cx.run_until_parked();
3453 fake_model.end_last_completion_stream();
3454 cx.run_until_parked();
3455
3456 // State is set to Error and default message
3457 thread.read_with(cx, |thread, _| {
3458 assert!(matches!(thread.summary(), ThreadSummary::Error));
3459 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3460 });
3461 }
3462
    /// Completes the model's pending request with a canned assistant reply.
    ///
    /// Parks before streaming so the request is actually issued, and parks again
    /// afterwards so the thread processes the completed stream.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3469
    /// Registers the global settings and subsystems the thread tests depend on.
    ///
    /// The settings store must be installed first; the remaining registrations
    /// read from it. NOTE(review): the call order below is assumed to matter for
    /// some of these initializers — keep additions at the end unless known safe.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3486
3487 // Helper to create a test project with test files
3488 async fn create_test_project(
3489 cx: &mut TestAppContext,
3490 files: serde_json::Value,
3491 ) -> Entity<Project> {
3492 let fs = FakeFs::new(cx.executor());
3493 fs.insert_tree(path!("/test"), files).await;
3494 Project::test(fs, [path!("/test").as_ref()], cx).await
3495 }
3496
3497 async fn setup_test_environment(
3498 cx: &mut TestAppContext,
3499 project: Entity<Project>,
3500 ) -> (
3501 Entity<Workspace>,
3502 Entity<ThreadStore>,
3503 Entity<Thread>,
3504 Entity<ContextStore>,
3505 Arc<dyn LanguageModel>,
3506 ) {
3507 let (workspace, cx) =
3508 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3509
3510 let thread_store = cx
3511 .update(|_, cx| {
3512 ThreadStore::load(
3513 project.clone(),
3514 cx.new(|_| ToolWorkingSet::default()),
3515 None,
3516 Arc::new(PromptBuilder::new(None).unwrap()),
3517 cx,
3518 )
3519 })
3520 .await
3521 .unwrap();
3522
3523 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3524 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3525
3526 let provider = Arc::new(FakeLanguageModelProvider);
3527 let model = provider.test_model();
3528 let model: Arc<dyn LanguageModel> = Arc::new(model);
3529
3530 cx.update(|_, cx| {
3531 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3532 registry.set_default_model(
3533 Some(ConfiguredModel {
3534 provider: provider.clone(),
3535 model: model.clone(),
3536 }),
3537 cx,
3538 );
3539 registry.set_thread_summary_model(
3540 Some(ConfiguredModel {
3541 provider,
3542 model: model.clone(),
3543 }),
3544 cx,
3545 );
3546 })
3547 });
3548
3549 (workspace, thread_store, thread, context_store, model)
3550 }
3551
3552 async fn add_file_to_context(
3553 project: &Entity<Project>,
3554 context_store: &Entity<ContextStore>,
3555 path: &str,
3556 cx: &mut TestAppContext,
3557 ) -> Result<Entity<language::Buffer>> {
3558 let buffer_path = project
3559 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3560 .unwrap();
3561
3562 let buffer = project
3563 .update(cx, |project, cx| {
3564 project.open_buffer(buffer_path.clone(), cx)
3565 })
3566 .await
3567 .unwrap();
3568
3569 context_store.update(cx, |context_store, cx| {
3570 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3571 });
3572
3573 Ok(buffer)
3574 }
3575}