1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Offset range of the crease within the message text.
    pub range: Range<usize>,
    /// Display metadata (icon and label) for the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (plain text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    /// Crease data used to restore context creases when re-editing this message.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// One contiguous piece of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Chain-of-thought text, optionally accompanied by a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking: opaque bytes with no displayable text.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has meaningful content to display.
    ///
    /// Fix: this previously returned `text.is_empty()`, i.e. it reported
    /// *empty* segments as displayable — inverted relative to its name and to
    /// the documented intent of `Message::should_display_content` ("returns
    /// whether the message contains any meaningful text").
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking carries no text that could be shown.
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Snapshot of the project's worktrees, captured as a thread's initial state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
202
/// Per-worktree state captured in a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git metadata for the worktree, if available.
    pub git_state: Option<GitState>,
}
208
/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff text, if one was captured.
    pub diff: Option<String>,
}
216
/// A git checkpoint tied to the user message it was captured for.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    // The message this checkpoint is anchored to (see `insert_user_message`).
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
222
/// User feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
229pub enum LastRestoreCheckpoint {
230 Pending {
231 message_id: MessageId,
232 },
233 Error {
234 message_id: MessageId,
235 error: String,
236 },
237}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
249pub enum DetailedSummaryState {
250 #[default]
251 NotGenerated,
252 Generating {
253 message_id: MessageId,
254 },
255 Generated {
256 text: SharedString,
257 message_id: MessageId,
258 },
259}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token count for a thread relative to the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size; `0` when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via
        // `ZED_THREAD_WARNING_THRESHOLD` for testing.
        //
        // Fix: the previous code called `.parse().unwrap()`, panicking on a
        // malformed override; now an unparseable value falls back to 0.8.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of `self` with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close the thread's token usage is to the context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    /// Usage is at or above the warning threshold (default 80%).
    Warning,
    /// Usage meets or exceeds the window size.
    Exceeded,
}
315
/// Where a pending completion stands in the request lifecycle.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent; no queue information yet.
    Sending,
    /// The request is waiting in a queue at the given position.
    Queued { position: usize },
    /// The completion has started streaming.
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    /// Task generating `summary`; a completed placeholder when idle.
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    /// Watch channel broadcasting detailed-summary progress.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Messages in ascending `MessageId` order (IDs come from `next_message_id`).
    messages: Vec<Message>,
    next_message_id: MessageId,
    /// ID of the most recent user prompt submission.
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message they were captured for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Status of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting finalization once generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Token usage accumulated across the thread's lifetime.
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// When the last streaming chunk arrived; drives staleness detection.
    last_received_chunk_at: Option<Instant>,
    /// Optional hook observing each request and its raw response events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining agentic turns; decremented by `send_to_model`.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
/// Lifecycle of a thread's summary (its short title).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// No summary exists yet.
    Pending,
    /// A summary is being generated.
    Generating,
    /// A summary is available.
    Ready(SharedString),
    /// Summary generation failed.
    Error,
}
371
372impl ThreadSummary {
373 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
374
375 pub fn or_default(&self) -> SharedString {
376 self.unwrap_or(Self::DEFAULT)
377 }
378
379 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
380 self.ready().unwrap_or_else(|| message.into())
381 }
382
383 pub fn ready(&self) -> Option<SharedString> {
384 match self {
385 ThreadSummary::Ready(summary) => Some(summary.clone()),
386 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
387 }
388 }
389}
390
/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
398
399impl Thread {
    /// Creates a new, empty thread with a fresh [`ThreadId`].
    ///
    /// The configured model starts as the registry's global default, and an
    /// initial project snapshot is captured asynchronously in the background.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project snapshot off the current call stack; shared
            // so later serializations can await the same result.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
453
    /// Reconstructs a thread from its serialized on-disk form.
    ///
    /// The next message ID continues after the highest deserialized ID, and
    /// the configured model is resolved from the serialized provider/model
    /// pair, falling back to the registry's default.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Serialization only retains the flattened context text;
                    // live handles and images are not persisted.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Context handles cannot be restored from disk.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
573
    /// Installs a callback invoked with each completion request and the raw
    /// events received in response.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
581
    /// Stable identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the modification time to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared project context used as the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
605
    /// Returns the configured model, initializing it from the registry's
    /// global default if none has been set yet. May still return `None` when
    /// no default exists.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
621
    /// The current summary state.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Replaces the summary with `new_summary`.
    ///
    /// Ignored while a summary is pending or generating. An empty string is
    /// normalized to the default title. Emits `SummaryChanged` only when the
    /// text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // After an error, compare against the default title so any real
            // summary counts as a change.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }
644
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
652
653 pub fn message(&self, id: MessageId) -> Option<&Message> {
654 let index = self
655 .messages
656 .binary_search_by(|message| message.id.cmp(&id))
657 .ok()?;
658
659 self.messages.get(index)
660 }
661
662 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
663 self.messages.iter()
664 }
665
    /// Whether a completion is in flight or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale, i.e.
    /// more than `STALE_THRESHOLD` milliseconds have passed since the last
    /// received chunk. Returns `None` when no chunk has ever been received.
    ///
    /// NOTE(review): despite the original doc claim, this does not consult
    /// `is_generating()` — callers should gate on it themselves; confirm.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
688
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses waiting for user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool use is still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
714
    /// Restores the project to `checkpoint`'s git state, truncating the
    /// thread back to the checkpoint's message on success.
    ///
    /// Progress and failure are tracked in `last_restore_checkpoint` so the
    /// UI can surface them; `CheckpointChanged` is emitted on every state
    /// transition.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so the UI can show why the
                    // restore failed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // A restore invalidates any not-yet-finalized checkpoint.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
750
    /// Promotes `pending_checkpoint` to a stored checkpoint once generation
    /// has finished, but only if the working tree actually changed since the
    /// checkpoint was taken. If the comparison itself fails, the checkpoint
    /// is kept conservatively.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint when something actually changed.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Could not take a final checkpoint: keep the pending one rather
            // than silently losing restore capability.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
789
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The status of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
800
    /// Removes the message with `message_id` and every message after it,
    /// dropping checkpoints associated with the removed messages. No-op when
    /// the ID is not found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
814
815 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
816 self.messages
817 .iter()
818 .find(|message| message.id == id)
819 .into_iter()
820 .flat_map(|message| message.loaded_context.contexts.iter())
821 }
822
823 pub fn is_turn_end(&self, ix: usize) -> bool {
824 if self.messages.is_empty() {
825 return false;
826 }
827
828 if !self.is_generating() && ix == self.messages.len() - 1 {
829 return true;
830 }
831
832 let Some(message) = self.messages.get(ix) else {
833 return false;
834 };
835
836 if message.role != Role::Assistant {
837 return false;
838 }
839
840 self.messages
841 .get(ix + 1)
842 .and_then(|message| {
843 self.message(message.id)
844 .map(|next_message| next_message.role == Role::User)
845 })
846 .unwrap_or(false)
847 }
848
    /// Usage information from the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the tool-use limit has been flagged for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
866
    /// Tool uses attached to the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a tool use; `None` for image results or when no result
    /// exists.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(str) => Some(str),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// UI card associated with a tool result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
895
896 /// Return tools that are both enabled and supported by the model
897 pub fn available_tools(
898 &self,
899 cx: &App,
900 model: Arc<dyn LanguageModel>,
901 ) -> Vec<LanguageModelRequestTool> {
902 if model.supports_tools() {
903 self.tools()
904 .read(cx)
905 .enabled_tools(cx)
906 .into_iter()
907 .filter_map(|tool| {
908 // Skip tools that cannot be supported
909 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
910 Some(LanguageModelRequestTool {
911 name: tool.name(),
912 description: tool.description(),
913 input_schema,
914 })
915 })
916 .collect()
917 } else {
918 Vec::default()
919 }
920 }
921
    /// Appends a user message, marking any referenced context buffers as read
    /// in the action log and stashing `git_checkpoint` as the pending
    /// checkpoint for this message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // The pending checkpoint is finalized (or discarded) by
        // `finalize_pending_checkpoint` once generation ends.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
957
    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
971
    /// Appends a message with the next sequential ID, bumps the modification
    /// time, and emits `MessageAdded`.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
992
993 pub fn edit_message(
994 &mut self,
995 id: MessageId,
996 new_role: Role,
997 new_segments: Vec<MessageSegment>,
998 loaded_context: Option<LoadedContext>,
999 cx: &mut Context<Self>,
1000 ) -> bool {
1001 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1002 return false;
1003 };
1004 message.role = new_role;
1005 message.segments = new_segments;
1006 if let Some(context) = loaded_context {
1007 message.loaded_context = context;
1008 }
1009 self.touch_updated_at();
1010 cx.emit(ThreadEvent::MessageEdited(id));
1011 true
1012 }
1013
1014 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1015 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1016 return false;
1017 };
1018 self.messages.remove(index);
1019 self.touch_updated_at();
1020 cx.emit(ThreadEvent::MessageDeleted(id));
1021 true
1022 }
1023
1024 /// Returns the representation of this [`Thread`] in a textual form.
1025 ///
1026 /// This is the representation we use when attaching a thread as context to another thread.
1027 pub fn text(&self) -> String {
1028 let mut text = String::new();
1029
1030 for message in &self.messages {
1031 text.push_str(match message.role {
1032 language_model::Role::User => "User:",
1033 language_model::Role::Assistant => "Agent:",
1034 language_model::Role::System => "System:",
1035 });
1036 text.push('\n');
1037
1038 for segment in &message.segments {
1039 match segment {
1040 MessageSegment::Text(content) => text.push_str(content),
1041 MessageSegment::Thinking { text: content, .. } => {
1042 text.push_str(&format!("<think>{}</think>", content))
1043 }
1044 MessageSegment::RedactedThinking(_) => {}
1045 }
1046 }
1047 text.push('\n');
1048 }
1049
1050 text
1051 }
1052
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot before reading the thread state
    /// so the snapshot can be embedded in the result.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted; live
                        // handles are rebuilt (as `None`) on deserialize.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1135
    /// Returns how many agent turns may still be sent to the model before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1139
    /// Sets the turn budget consumed by `send_to_model`.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1143
1144 pub fn send_to_model(
1145 &mut self,
1146 model: Arc<dyn LanguageModel>,
1147 window: Option<AnyWindowHandle>,
1148 cx: &mut Context<Self>,
1149 ) {
1150 if self.remaining_turns == 0 {
1151 return;
1152 }
1153
1154 self.remaining_turns -= 1;
1155
1156 let request = self.to_completion_request(model.clone(), cx);
1157
1158 self.stream_completion(request, model, window, cx);
1159 }
1160
1161 pub fn used_tools_since_last_user_message(&self) -> bool {
1162 for message in self.messages.iter().rev() {
1163 if self.tool_use.message_has_tool_results(message.id) {
1164 return true;
1165 } else if message.role == Role::User {
1166 return false;
1167 }
1168 }
1169
1170 false
1171 }
1172
    /// Assembles the complete `LanguageModelRequest` for this thread: system
    /// prompt, message history (with loaded context), paired tool uses and
    /// results, stale-file notices, available tools, and completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // The system prompt is told which tool names are available.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the project context. Failures are
        // surfaced to the UI rather than aborting request construction.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Track the latest message index that is safe to mark as a prompt-cache
        // anchor; a message with still-pending tool uses is not cacheable.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Context loaded for this message (files, symbols, etc.) goes in
            // before the message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool uses ride on the assistant message; their results are sent
            // in a follow-up user-role message, as the API requires.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        // Append a trailing notice about buffers that changed since last read.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1330
1331 fn to_summarize_request(
1332 &self,
1333 model: &Arc<dyn LanguageModel>,
1334 added_user_message: String,
1335 cx: &App,
1336 ) -> LanguageModelRequest {
1337 let mut request = LanguageModelRequest {
1338 thread_id: None,
1339 prompt_id: None,
1340 mode: None,
1341 messages: vec![],
1342 tools: Vec::new(),
1343 tool_choice: None,
1344 stop: Vec::new(),
1345 temperature: AssistantSettings::temperature_for_model(model, cx),
1346 };
1347
1348 for message in &self.messages {
1349 let mut request_message = LanguageModelRequestMessage {
1350 role: message.role,
1351 content: Vec::new(),
1352 cache: false,
1353 };
1354
1355 for segment in &message.segments {
1356 match segment {
1357 MessageSegment::Text(text) => request_message
1358 .content
1359 .push(MessageContent::Text(text.clone())),
1360 MessageSegment::Thinking { .. } => {}
1361 MessageSegment::RedactedThinking(_) => {}
1362 }
1363 }
1364
1365 if request_message.content.is_empty() {
1366 continue;
1367 }
1368
1369 request.messages.push(request_message);
1370 }
1371
1372 request.messages.push(LanguageModelRequestMessage {
1373 role: Role::User,
1374 content: vec![MessageContent::Text(added_user_message)],
1375 cache: false,
1376 });
1377
1378 request
1379 }
1380
1381 fn attached_tracked_files_state(
1382 &self,
1383 messages: &mut Vec<LanguageModelRequestMessage>,
1384 cx: &App,
1385 ) {
1386 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1387
1388 let mut stale_message = String::new();
1389
1390 let action_log = self.action_log.read(cx);
1391
1392 for stale_file in action_log.stale_buffers(cx) {
1393 let Some(file) = stale_file.read(cx).file() else {
1394 continue;
1395 };
1396
1397 if stale_message.is_empty() {
1398 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1399 }
1400
1401 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1402 }
1403
1404 let mut content = Vec::with_capacity(2);
1405
1406 if !stale_message.is_empty() {
1407 content.push(stale_message.into());
1408 }
1409
1410 if !content.is_empty() {
1411 let context_message = LanguageModelRequestMessage {
1412 role: Role::User,
1413 content,
1414 cache: false,
1415 };
1416
1417 messages.push(context_message);
1418 }
1419 }
1420
    /// Spawns a task that streams `model`'s response to `request`, applying
    /// each completion event to the thread (text/thinking chunks, tool uses,
    /// token usage, queue status) and handling errors and stop reasons when
    /// the stream ends. The task is registered in `pending_completions`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record request/response pairs when a callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the delta can be reported below.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        // Recoverable errors (bad tool-input JSON) are handled
                        // inline; any other error aborts the stream.
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // only add the delta since the last update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // A tool use needs an assistant message to hang
                                // off of; create one if none was started yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only emit a streaming event while the tool's
                                // input JSON is still arriving.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Let other tasks make progress between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    // React to how the stream ended: run requested tools,
                    // clear the agent location, or surface the error.
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to dedicated UI states;
                            // everything else becomes a generic error message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the token delta for this request to telemetry.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1763
1764 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1765 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1766 println!("No thread summary model");
1767 return;
1768 };
1769
1770 if !model.provider.is_authenticated(cx) {
1771 return;
1772 }
1773
1774 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1775 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1776 If the conversation is about a specific subject, include it in the title. \
1777 Be descriptive. DO NOT speak in the first person.";
1778
1779 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1780
1781 self.summary = ThreadSummary::Generating;
1782
1783 self.pending_summary = cx.spawn(async move |this, cx| {
1784 let result = async {
1785 let mut messages = model.model.stream_completion(request, &cx).await?;
1786
1787 let mut new_summary = String::new();
1788 while let Some(event) = messages.next().await {
1789 let Ok(event) = event else {
1790 continue;
1791 };
1792 let text = match event {
1793 LanguageModelCompletionEvent::Text(text) => text,
1794 LanguageModelCompletionEvent::StatusUpdate(
1795 CompletionRequestStatus::UsageUpdated { amount, limit },
1796 ) => {
1797 this.update(cx, |thread, _cx| {
1798 thread.last_usage = Some(RequestUsage {
1799 limit,
1800 amount: amount as i32,
1801 });
1802 })?;
1803 continue;
1804 }
1805 _ => continue,
1806 };
1807
1808 let mut lines = text.lines();
1809 new_summary.extend(lines.next());
1810
1811 // Stop if the LLM generated multiple lines.
1812 if lines.next().is_some() {
1813 break;
1814 }
1815 }
1816
1817 anyhow::Ok(new_summary)
1818 }
1819 .await;
1820
1821 this.update(cx, |this, cx| {
1822 match result {
1823 Ok(new_summary) => {
1824 if new_summary.is_empty() {
1825 this.summary = ThreadSummary::Error;
1826 } else {
1827 this.summary = ThreadSummary::Ready(new_summary.into());
1828 }
1829 }
1830 Err(err) => {
1831 this.summary = ThreadSummary::Error;
1832 log::error!("Failed to generate thread summary: {}", err);
1833 }
1834 }
1835 cx.emit(ThreadEvent::SummaryGenerated);
1836 })
1837 .log_err()?;
1838
1839 Some(())
1840 });
1841 }
1842
    /// Starts generating a detailed Markdown summary of the thread unless one
    /// is already generating (or generated) for the current last message.
    ///
    /// On completion the thread is saved via `thread_store` so the summary can
    /// be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        // Announce the transition so observers of the watch channel see
        // `Generating` before any result arrives.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // A failed request resets the state back to `NotGenerated`.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1932
1933 pub async fn wait_for_detailed_summary_or_text(
1934 this: &Entity<Self>,
1935 cx: &mut AsyncApp,
1936 ) -> Option<SharedString> {
1937 let mut detailed_summary_rx = this
1938 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1939 .ok()?;
1940 loop {
1941 match detailed_summary_rx.recv().await? {
1942 DetailedSummaryState::Generating { .. } => {}
1943 DetailedSummaryState::NotGenerated => {
1944 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1945 }
1946 DetailedSummaryState::Generated { text, .. } => return Some(text),
1947 }
1948 }
1949 }
1950
1951 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1952 self.detailed_summary_rx
1953 .borrow()
1954 .text()
1955 .unwrap_or_else(|| self.text().into())
1956 }
1957
    /// Whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1964
    /// Dispatches every idle pending tool use: tools requiring confirmation
    /// emit `ToolConfirmationNeeded`, runnable tools are executed, and tool
    /// names that match no enabled tool are handled as hallucinated.
    ///
    /// Returns the pending tool uses that were dispatched.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Build the request once and share it with every tool run.
        let request = Arc::new(self.to_completion_request(model.clone(), cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                // Ask the user first, unless settings allow all tool actions.
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool we don't have.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2018
2019 pub fn handle_hallucinated_tool_use(
2020 &mut self,
2021 tool_use_id: LanguageModelToolUseId,
2022 hallucinated_tool_name: Arc<str>,
2023 window: Option<AnyWindowHandle>,
2024 cx: &mut Context<Thread>,
2025 ) {
2026 let available_tools = self.tools.read(cx).enabled_tools(cx);
2027
2028 let tool_list = available_tools
2029 .iter()
2030 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2031 .collect::<Vec<_>>()
2032 .join("\n");
2033
2034 let error_message = format!(
2035 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2036 hallucinated_tool_name, tool_list
2037 );
2038
2039 let pending_tool_use = self.tool_use.insert_tool_output(
2040 tool_use_id.clone(),
2041 hallucinated_tool_name,
2042 Err(anyhow!("Missing tool call: {error_message}")),
2043 self.configured_model.as_ref(),
2044 );
2045
2046 cx.emit(ThreadEvent::MissingToolUse {
2047 tool_use_id: tool_use_id.clone(),
2048 ui_text: error_message.into(),
2049 });
2050
2051 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2052 }
2053
2054 pub fn receive_invalid_tool_json(
2055 &mut self,
2056 tool_use_id: LanguageModelToolUseId,
2057 tool_name: Arc<str>,
2058 invalid_json: Arc<str>,
2059 error: String,
2060 window: Option<AnyWindowHandle>,
2061 cx: &mut Context<Thread>,
2062 ) {
2063 log::error!("The model returned invalid input JSON: {invalid_json}");
2064
2065 let pending_tool_use = self.tool_use.insert_tool_output(
2066 tool_use_id.clone(),
2067 tool_name,
2068 Err(anyhow!("Error parsing input JSON: {error}")),
2069 self.configured_model.as_ref(),
2070 );
2071 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2072 pending_tool_use.ui_text.clone()
2073 } else {
2074 log::error!(
2075 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2076 );
2077 format!("Unknown tool {}", tool_use_id).into()
2078 };
2079
2080 cx.emit(ThreadEvent::InvalidToolInput {
2081 tool_use_id: tool_use_id.clone(),
2082 ui_text,
2083 invalid_input_json: invalid_json,
2084 });
2085
2086 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2087 }
2088
2089 pub fn run_tool(
2090 &mut self,
2091 tool_use_id: LanguageModelToolUseId,
2092 ui_text: impl Into<SharedString>,
2093 input: serde_json::Value,
2094 request: Arc<LanguageModelRequest>,
2095 tool: Arc<dyn Tool>,
2096 model: Arc<dyn LanguageModel>,
2097 window: Option<AnyWindowHandle>,
2098 cx: &mut Context<Thread>,
2099 ) {
2100 let task =
2101 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2102 self.tool_use
2103 .run_pending_tool(tool_use_id, ui_text.into(), task);
2104 }
2105
    /// Runs `tool` (or short-circuits with an error if it's disabled) and
    /// returns a task that records the tool's output on the thread when it
    /// finishes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Disabled tools immediately resolve to an error result.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // Record the output and notify; the update is skipped if the
                // thread was dropped in the meantime.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2156
2157 fn tool_finished(
2158 &mut self,
2159 tool_use_id: LanguageModelToolUseId,
2160 pending_tool_use: Option<PendingToolUse>,
2161 canceled: bool,
2162 window: Option<AnyWindowHandle>,
2163 cx: &mut Context<Self>,
2164 ) {
2165 if self.all_tools_finished() {
2166 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2167 if !canceled {
2168 self.send_to_model(model.clone(), window, cx);
2169 }
2170 self.auto_capture_telemetry(cx);
2171 }
2172 }
2173
2174 cx.emit(ThreadEvent::ToolFinished {
2175 tool_use_id,
2176 pending_tool_use,
2177 });
2178 }
2179
2180 /// Cancels the last pending completion, if there are any pending.
2181 ///
2182 /// Returns whether a completion was canceled.
2183 pub fn cancel_last_completion(
2184 &mut self,
2185 window: Option<AnyWindowHandle>,
2186 cx: &mut Context<Self>,
2187 ) -> bool {
2188 let mut canceled = self.pending_completions.pop().is_some();
2189
2190 for pending_tool_use in self.tool_use.cancel_pending() {
2191 canceled = true;
2192 self.tool_finished(
2193 pending_tool_use.id.clone(),
2194 Some(pending_tool_use),
2195 true,
2196 window,
2197 cx,
2198 );
2199 }
2200
2201 self.finalize_pending_checkpoint(cx);
2202
2203 if canceled {
2204 cx.emit(ThreadEvent::CompletionCanceled);
2205 }
2206
2207 canceled
2208 }
2209
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // Purely a broadcast; no thread state changes here.
        cx.emit(ThreadEvent::CancelEditing);
    }
2217
    /// Returns the thread-level feedback, if any was reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2221
    /// Returns the feedback reported for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2225
    /// Records feedback for `message_id` and reports it to telemetry together
    /// with a serialized snapshot of the thread and project.
    ///
    /// Returns a task that resolves once the telemetry event has been flushed;
    /// resolves immediately when the same feedback was already recorded.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op when the feedback hasn't changed.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything the background task needs before moving off the
        // main context.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2283
2284 pub fn report_feedback(
2285 &mut self,
2286 feedback: ThreadFeedback,
2287 cx: &mut Context<Self>,
2288 ) -> Task<Result<()>> {
2289 let last_assistant_message_id = self
2290 .messages
2291 .iter()
2292 .rev()
2293 .find(|msg| msg.role == Role::Assistant)
2294 .map(|msg| msg.id);
2295
2296 if let Some(message_id) = last_assistant_message_id {
2297 self.report_message_feedback(message_id, feedback, cx)
2298 } else {
2299 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2300 let serialized_thread = self.serialize(cx);
2301 let thread_id = self.id().clone();
2302 let client = self.project.read(cx).client();
2303 self.feedback = Some(feedback);
2304 cx.notify();
2305
2306 cx.background_spawn(async move {
2307 let final_project_snapshot = final_project_snapshot.await;
2308 let serialized_thread = serialized_thread.await?;
2309 let thread_data = serde_json::to_value(serialized_thread)
2310 .unwrap_or_else(|_| serde_json::Value::Null);
2311
2312 let rating = match feedback {
2313 ThreadFeedback::Positive => "positive",
2314 ThreadFeedback::Negative => "negative",
2315 };
2316 telemetry::event!(
2317 "Assistant Thread Rated",
2318 rating,
2319 thread_id,
2320 thread_data,
2321 final_project_snapshot
2322 );
2323 client.telemetry().flush_events().await;
2324
2325 Ok(())
2326 })
2327 }
2328 }
2329
2330 /// Create a snapshot of the current project state including git information and unsaved buffers.
2331 fn project_snapshot(
2332 project: Entity<Project>,
2333 cx: &mut Context<Self>,
2334 ) -> Task<Arc<ProjectSnapshot>> {
2335 let git_store = project.read(cx).git_store().clone();
2336 let worktree_snapshots: Vec<_> = project
2337 .read(cx)
2338 .visible_worktrees(cx)
2339 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2340 .collect();
2341
2342 cx.spawn(async move |_, cx| {
2343 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2344
2345 let mut unsaved_buffers = Vec::new();
2346 cx.update(|app_cx| {
2347 let buffer_store = project.read(app_cx).buffer_store();
2348 for buffer_handle in buffer_store.read(app_cx).buffers() {
2349 let buffer = buffer_handle.read(app_cx);
2350 if buffer.is_dirty() {
2351 if let Some(file) = buffer.file() {
2352 let path = file.path().to_string_lossy().to_string();
2353 unsaved_buffers.push(path);
2354 }
2355 }
2356 }
2357 })
2358 .ok();
2359
2360 Arc::new(ProjectSnapshot {
2361 worktree_snapshots,
2362 unsaved_buffer_paths: unsaved_buffers,
2363 timestamp: Utc::now(),
2364 })
2365 })
2366 }
2367
    /// Captures a point-in-time snapshot of one worktree: its absolute path
    /// plus, when a git repository contains it, the repository's remote URL,
    /// HEAD SHA, current branch, and a HEAD-to-worktree diff.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is unavailable (e.g. it was dropped), degrade
            // to an empty snapshot rather than failing the whole capture.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose working copy contains this worktree's
            // root, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        // Queue the metadata read on the repository's job queue.
                        // Only local repositories expose a backend; other states
                        // report just the branch name.
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task<..>>>: any failure along the
            // way (dropped entity, failed job) degrades to `None` git state.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2445
2446 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2447 let mut markdown = Vec::new();
2448
2449 let summary = self.summary().or_default();
2450 writeln!(markdown, "# {summary}\n")?;
2451
2452 for message in self.messages() {
2453 writeln!(
2454 markdown,
2455 "## {role}\n",
2456 role = match message.role {
2457 Role::User => "User",
2458 Role::Assistant => "Agent",
2459 Role::System => "System",
2460 }
2461 )?;
2462
2463 if !message.loaded_context.text.is_empty() {
2464 writeln!(markdown, "{}", message.loaded_context.text)?;
2465 }
2466
2467 if !message.loaded_context.images.is_empty() {
2468 writeln!(
2469 markdown,
2470 "\n{} images attached as context.\n",
2471 message.loaded_context.images.len()
2472 )?;
2473 }
2474
2475 for segment in &message.segments {
2476 match segment {
2477 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2478 MessageSegment::Thinking { text, .. } => {
2479 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2480 }
2481 MessageSegment::RedactedThinking(_) => {}
2482 }
2483 }
2484
2485 for tool_use in self.tool_uses_for_message(message.id, cx) {
2486 writeln!(
2487 markdown,
2488 "**Use Tool: {} ({})**",
2489 tool_use.name, tool_use.id
2490 )?;
2491 writeln!(markdown, "```json")?;
2492 writeln!(
2493 markdown,
2494 "{}",
2495 serde_json::to_string_pretty(&tool_use.input)?
2496 )?;
2497 writeln!(markdown, "```")?;
2498 }
2499
2500 for tool_result in self.tool_results_for_message(message.id) {
2501 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2502 if tool_result.is_error {
2503 write!(markdown, " (Error)")?;
2504 }
2505
2506 writeln!(markdown, "**\n")?;
2507 match &tool_result.content {
2508 LanguageModelToolResultContent::Text(str) => {
2509 writeln!(markdown, "{}", str)?;
2510 }
2511 LanguageModelToolResultContent::Image(image) => {
2512 writeln!(markdown, "", image.source)?;
2513 }
2514 }
2515
2516 if let Some(output) = tool_result.output.as_ref() {
2517 writeln!(
2518 markdown,
2519 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2520 serde_json::to_string_pretty(output)?
2521 )?;
2522 }
2523 }
2524 }
2525
2526 Ok(String::from_utf8_lossy(&markdown).to_string())
2527 }
2528
2529 pub fn keep_edits_in_range(
2530 &mut self,
2531 buffer: Entity<language::Buffer>,
2532 buffer_range: Range<language::Anchor>,
2533 cx: &mut Context<Self>,
2534 ) {
2535 self.action_log.update(cx, |action_log, cx| {
2536 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2537 });
2538 }
2539
2540 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2541 self.action_log
2542 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2543 }
2544
2545 pub fn reject_edits_in_ranges(
2546 &mut self,
2547 buffer: Entity<language::Buffer>,
2548 buffer_ranges: Vec<Range<language::Anchor>>,
2549 cx: &mut Context<Self>,
2550 ) -> Task<Result<()>> {
2551 self.action_log.update(cx, |action_log, cx| {
2552 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2553 })
2554 }
2555
    /// Returns the [`ActionLog`] tracking the agent's edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2559
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2563
2564 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2565 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2566 return;
2567 }
2568
2569 let now = Instant::now();
2570 if let Some(last) = self.last_auto_capture_at {
2571 if now.duration_since(last).as_secs() < 10 {
2572 return;
2573 }
2574 }
2575
2576 self.last_auto_capture_at = Some(now);
2577
2578 let thread_id = self.id().clone();
2579 let github_login = self
2580 .project
2581 .read(cx)
2582 .user_store()
2583 .read(cx)
2584 .current_user()
2585 .map(|user| user.github_login.clone());
2586 let client = self.project.read(cx).client();
2587 let serialize_task = self.serialize(cx);
2588
2589 cx.background_executor()
2590 .spawn(async move {
2591 if let Ok(serialized_thread) = serialize_task.await {
2592 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2593 telemetry::event!(
2594 "Agent Thread Auto-Captured",
2595 thread_id = thread_id.to_string(),
2596 thread_data = thread_data,
2597 auto_capture_reason = "tracked_user",
2598 github_login = github_login
2599 );
2600
2601 client.telemetry().flush_events().await;
2602 }
2603 }
2604 })
2605 .detach();
2606 }
2607
    /// Returns the token usage accumulated across every completion made by
    /// this thread so far.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2611
2612 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2613 let Some(model) = self.configured_model.as_ref() else {
2614 return TotalTokenUsage::default();
2615 };
2616
2617 let max = model.model.max_token_count();
2618
2619 let index = self
2620 .messages
2621 .iter()
2622 .position(|msg| msg.id == message_id)
2623 .unwrap_or(0);
2624
2625 if index == 0 {
2626 return TotalTokenUsage { total: 0, max };
2627 }
2628
2629 let token_usage = &self
2630 .request_token_usage
2631 .get(index - 1)
2632 .cloned()
2633 .unwrap_or_default();
2634
2635 TotalTokenUsage {
2636 total: token_usage.total_tokens() as usize,
2637 max,
2638 }
2639 }
2640
2641 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2642 let model = self.configured_model.as_ref()?;
2643
2644 let max = model.model.max_token_count();
2645
2646 if let Some(exceeded_error) = &self.exceeded_window_error {
2647 if model.model.id() == exceeded_error.model_id {
2648 return Some(TotalTokenUsage {
2649 total: exceeded_error.token_count,
2650 max,
2651 });
2652 }
2653 }
2654
2655 let total = self
2656 .token_usage_at_last_message()
2657 .unwrap_or_default()
2658 .total_tokens() as usize;
2659
2660 Some(TotalTokenUsage { total, max })
2661 }
2662
    /// Returns the token usage associated with the last message, if recorded.
    ///
    /// Looks up the entry at `messages.len() - 1`; if the usage vector is
    /// shorter than the message list (usage hasn't caught up yet), falls back
    /// to the most recent entry available.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2669
    /// Records `token_usage` as the usage for the last message.
    ///
    /// If the usage vector is shorter than the message list, it is first
    /// padded out (with the most recent known usage as a placeholder) so that
    /// there is one entry per message, then the final entry is overwritten.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2679
2680 pub fn deny_tool_use(
2681 &mut self,
2682 tool_use_id: LanguageModelToolUseId,
2683 tool_name: Arc<str>,
2684 window: Option<AnyWindowHandle>,
2685 cx: &mut Context<Self>,
2686 ) {
2687 let err = Err(anyhow::anyhow!(
2688 "Permission to run tool action denied by user"
2689 ));
2690
2691 self.tool_use.insert_tool_output(
2692 tool_use_id.clone(),
2693 tool_name,
2694 err,
2695 self.configured_model.as_ref(),
2696 );
2697 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2698 }
2699}
2700
/// User-visible errors surfaced by a thread (via `ThreadEvent::ShowError`).
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider requires payment before further completions.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has exhausted its model request allowance.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a longer message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2713
/// Events emitted by a thread as completions stream, tools run, and messages
/// change. Listeners (e.g. the active thread view) use these to keep UI state
/// in sync.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be presented to the user.
    ShowError(ThreadError),
    /// A completion produced streamed output.
    StreamedCompletion,
    /// A chunk of text was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed plain-text content for the given assistant message.
    StreamedAssistantText(MessageId, String),
    /// Streamed thinking/reasoning content for the given assistant message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with its (possibly partial) input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use was reported missing; `ui_text` describes it for display.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be parsed.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped (with a stop reason) or failed with an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message's content was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// The model requested tool uses that should now be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool is waiting on user confirmation before it may run.
    ToolConfirmationNeeded,
    /// Any in-progress editing should be canceled (see `Thread::cancel_editing`).
    CancelEditing,
    /// A pending completion was canceled (see `Thread::cancel_last_completion`).
    CompletionCanceled,
}
2756
/// Lets `Thread` broadcast [`ThreadEvent`]s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2758
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier distinguishing this completion from other in-flight ones.
    id: usize,
    queue_state: QueueState,
    // Keeps the completion task alive; dropping it (e.g. when popped in
    // `cancel_last_completion`) cancels the request.
    _task: Task<()>,
}
2764
2765#[cfg(test)]
2766mod tests {
2767 use super::*;
2768 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2769 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2770 use assistant_tool::ToolRegistry;
2771 use editor::EditorSettings;
2772 use gpui::TestAppContext;
2773 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2774 use project::{FakeFs, Project};
2775 use prompt_store::PromptBuilder;
2776 use serde_json::json;
2777 use settings::{Settings, SettingsStore};
2778 use std::sync::Arc;
2779 use theme::ThemeSettings;
2780 use util::path;
2781 use workspace::Workspace;
2782
    /// Verifies that a user message with an attached file context stores the
    /// rendered `<context>` block on the message and prepends it to the
    /// message text in the outgoing completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // messages[0] is the system prompt; messages[1] is the user message
        // with the context block prepended.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2858
    /// Verifies context deduplication across messages: each new user message
    /// only carries contexts not already included in an earlier message, and
    /// `new_context_for_thread` can re-evaluate relative to a past message id
    /// as contexts are added and removed.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages (plus the system prompt at index 0)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Evaluating "new" contexts relative to message 2 should consider
        // contexts 2-4 new (only context 1 predates it).
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3014
    /// Verifies that user messages with no attached context carry an empty
    /// context string and appear unmodified in completion requests.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // messages[0] is the system prompt.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3092
    /// Verifies that when a buffer attached as context is modified after being
    /// read, the next completion request ends with a "files changed since last
    /// read" notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3179
    /// Verifies how `model_parameters` settings resolve a request temperature:
    /// a rule matches when its provider and/or model filters match the active
    /// model, and a provider mismatch means no override applies.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider — the rule must NOT apply.
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3273
    /// Walks the summary state machine: Pending (manual set rejected) →
    /// Generating after the first exchange (manual set still rejected) →
    /// Ready once the summary stream completes (manual set allowed; empty
    /// summaries fall back to the default).
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(model.clone(), None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary in two chunks, then end the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief".into());
        fake_model.stream_last_completion_response(" Introduction".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3353
3354 #[gpui::test]
3355 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3356 init_test_settings(cx);
3357
3358 let project = create_test_project(cx, json!({})).await;
3359
3360 let (_, _thread_store, thread, _context_store, model) =
3361 setup_test_environment(cx, project.clone()).await;
3362
3363 test_summarize_error(&model, &thread, cx);
3364
3365 // Now we should be able to set a summary
3366 thread.update(cx, |thread, cx| {
3367 thread.set_summary("Brief Intro", cx);
3368 });
3369
3370 thread.read_with(cx, |thread, _| {
3371 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3372 assert_eq!(thread.summary().or_default(), "Brief Intro");
3373 });
3374 }
3375
3376 #[gpui::test]
3377 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3378 init_test_settings(cx);
3379
3380 let project = create_test_project(cx, json!({})).await;
3381
3382 let (_, _thread_store, thread, _context_store, model) =
3383 setup_test_environment(cx, project.clone()).await;
3384
3385 test_summarize_error(&model, &thread, cx);
3386
3387 // Sending another message should not trigger another summarize request
3388 thread.update(cx, |thread, cx| {
3389 thread.insert_user_message(
3390 "How are you?",
3391 ContextLoadResult::default(),
3392 None,
3393 vec![],
3394 cx,
3395 );
3396 thread.send_to_model(model.clone(), None, cx);
3397 });
3398
3399 let fake_model = model.as_fake();
3400 simulate_successful_response(&fake_model, cx);
3401
3402 thread.read_with(cx, |thread, _| {
3403 // State is still Error, not Generating
3404 assert!(matches!(thread.summary(), ThreadSummary::Error));
3405 });
3406
3407 // But the summarize request can be invoked manually
3408 thread.update(cx, |thread, cx| {
3409 thread.summarize(cx);
3410 });
3411
3412 thread.read_with(cx, |thread, _| {
3413 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3414 });
3415
3416 cx.run_until_parked();
3417 fake_model.stream_last_completion_response("A successful summary".into());
3418 fake_model.end_last_completion_stream();
3419 cx.run_until_parked();
3420
3421 thread.read_with(cx, |thread, _| {
3422 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3423 assert_eq!(thread.summary().or_default(), "A successful summary");
3424 });
3425 }
3426
3427 fn test_summarize_error(
3428 model: &Arc<dyn LanguageModel>,
3429 thread: &Entity<Thread>,
3430 cx: &mut TestAppContext,
3431 ) {
3432 thread.update(cx, |thread, cx| {
3433 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3434 thread.send_to_model(model.clone(), None, cx);
3435 });
3436
3437 let fake_model = model.as_fake();
3438 simulate_successful_response(&fake_model, cx);
3439
3440 thread.read_with(cx, |thread, _| {
3441 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3442 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3443 });
3444
3445 // Simulate summary request ending
3446 cx.run_until_parked();
3447 fake_model.end_last_completion_stream();
3448 cx.run_until_parked();
3449
3450 // State is set to Error and default message
3451 thread.read_with(cx, |thread, _| {
3452 assert!(matches!(thread.summary(), ThreadSummary::Error));
3453 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3454 });
3455 }
3456
3457 fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
3458 cx.run_until_parked();
3459 fake_model.stream_last_completion_response("Assistant response".into());
3460 fake_model.end_last_completion_stream();
3461 cx.run_until_parked();
3462 }
3463
    /// Installs the global test settings store and registers every settings
    /// type and global registry these thread tests touch.
    ///
    /// The settings store is installed first; the registration calls below
    /// presumably read it when they run — keep it at the top.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            // Ensures the global tool registry exists before threads run tools.
            ToolRegistry::default_global(cx);
        });
    }
3480
3481 // Helper to create a test project with test files
3482 async fn create_test_project(
3483 cx: &mut TestAppContext,
3484 files: serde_json::Value,
3485 ) -> Entity<Project> {
3486 let fs = FakeFs::new(cx.executor());
3487 fs.insert_tree(path!("/test"), files).await;
3488 Project::test(fs, [path!("/test").as_ref()], cx).await
3489 }
3490
3491 async fn setup_test_environment(
3492 cx: &mut TestAppContext,
3493 project: Entity<Project>,
3494 ) -> (
3495 Entity<Workspace>,
3496 Entity<ThreadStore>,
3497 Entity<Thread>,
3498 Entity<ContextStore>,
3499 Arc<dyn LanguageModel>,
3500 ) {
3501 let (workspace, cx) =
3502 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3503
3504 let thread_store = cx
3505 .update(|_, cx| {
3506 ThreadStore::load(
3507 project.clone(),
3508 cx.new(|_| ToolWorkingSet::default()),
3509 None,
3510 Arc::new(PromptBuilder::new(None).unwrap()),
3511 cx,
3512 )
3513 })
3514 .await
3515 .unwrap();
3516
3517 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3518 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3519
3520 let provider = Arc::new(FakeLanguageModelProvider);
3521 let model = provider.test_model();
3522 let model: Arc<dyn LanguageModel> = Arc::new(model);
3523
3524 cx.update(|_, cx| {
3525 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3526 registry.set_default_model(
3527 Some(ConfiguredModel {
3528 provider: provider.clone(),
3529 model: model.clone(),
3530 }),
3531 cx,
3532 );
3533 registry.set_thread_summary_model(
3534 Some(ConfiguredModel {
3535 provider,
3536 model: model.clone(),
3537 }),
3538 cx,
3539 );
3540 })
3541 });
3542
3543 (workspace, thread_store, thread, context_store, model)
3544 }
3545
3546 async fn add_file_to_context(
3547 project: &Entity<Project>,
3548 context_store: &Entity<ContextStore>,
3549 path: &str,
3550 cx: &mut TestAppContext,
3551 ) -> Result<Entity<language::Buffer>> {
3552 let buffer_path = project
3553 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3554 .unwrap();
3555
3556 let buffer = project
3557 .update(cx, |project, cx| {
3558 project.open_buffer(buffer_path.clone(), cx)
3559 })
3560 .await
3561 .unwrap();
3562
3563 context_store.update(cx, |context_store, cx| {
3564 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3565 });
3566
3567 Ok(buffer)
3568 }
3569}