1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, TryFutureExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
/// Globally unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
///
/// Backed by a freshly generated UUID string; see [`PromptId::new`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
/// Monotonically increasing, thread-local index of a [`Message`].
///
/// Assigned from `Thread::next_message_id` on insertion, so ids are unique
/// and ordered within a single thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message's text that the crease covers.
    pub range: Range<usize>,
    /// Icon and label used when rendering the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded and attached when the message was sent.
    pub loaded_context: LoadedContext,
    /// Creases used to re-create context pills in editors for this message.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has user-visible content.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is never shown.
    pub fn should_display(&self) -> bool {
        match self {
            // BUGFIX: these previously returned `is_empty()`, which inverted
            // the check — empty segments were reported as displayable and
            // non-empty ones as hidden.
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Point-in-time capture of the project's worktrees and unsaved buffers,
/// stored alongside the serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
202
/// Snapshot of a single worktree, with optional git metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree root.
    pub worktree_path: String,
    /// Git state, when the worktree has an associated repository.
    pub git_state: Option<GitState>,
}

/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// Remote URL, if one was resolved.
    pub remote_url: Option<String>,
    /// SHA of HEAD, if resolved.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    /// Diff text, if one was captured.
    pub diff: Option<String>,
}
216
/// Associates a git checkpoint with the message that created it, so the
/// project can later be restored to the state it had at that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up / thumbs down) on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
/// Outcome/progress of the most recent checkpoint restoration.
pub enum LastRestoreCheckpoint {
    /// Restoration is still running.
    Pending {
        message_id: MessageId,
    },
    /// Restoration failed; `error` holds the failure message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
/// State machine for the thread's detailed (long-form) summary.
///
/// NOTE(review): the `message_id` fields appear to track which message the
/// summary covers — confirm against the summary-generation code.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// A detailed summary is currently being generated.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary was produced.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token usage for a thread, measured against a model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// The model's context-window limit; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage relative to the context window.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable; a missing or
    /// unparsable value falls back to the default instead of panicking
    /// (previously an invalid value would `unwrap` and crash).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy with `tokens` added to the running total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            // Saturate rather than overflow-panic in debug builds.
            total: self.total.saturating_add(tokens),
            max: self.max,
        }
    }
}

/// How close token usage is to the context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold.
    #[default]
    Normal,
    /// At or above the warning threshold, but still within the limit.
    Warning,
    /// At or above the context-window limit.
    Exceeded,
}
315
/// Lifecycle of a pending completion while waiting for the model to respond.
///
/// NOTE(review): variant semantics inferred from names (request sent →
/// queued at `position` → model started) — confirm against the
/// completion-streaming code.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// `None` until a summary has been generated; `set_summary` refuses to
    /// overwrite it before then.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Messages ordered by ascending [`MessageId`]; `message()` binary-searches this.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per user message, used for restoration.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// In-flight completion requests; empty when no completion is running.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken at the latest user message; kept or discarded once the
    /// turn finishes (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Token usage accumulated across all requests in this thread.
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Turns left before `send_to_model` becomes a no-op; `u32::MAX` by default.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    configured_profile: Option<AgentProfile>,
}
364
/// Details of the most recent request that exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
372
373impl Thread {
    /// Creates an empty [`Thread`] for `project`.
    ///
    /// The default model and profile are read from global assistant settings,
    /// and an initial project snapshot is captured on a background task.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let assistant_settings = AssistantSettings::get_global(cx);
        let profile_id = &assistant_settings.default_profile;
        let configured_profile = assistant_settings.profiles.get(profile_id).cloned();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            // Stays `None` until a summary is generated; `set_summary`
            // refuses to overwrite it before then.
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the snapshot once and share the task so multiple
                // consumers can await the same result.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            // Effectively unlimited until a caller restricts it.
            remaining_turns: u32::MAX,
            configured_model,
            configured_profile,
        }
    }
431
    /// Reconstructs a [`Thread`] from its serialized form.
    ///
    /// Persisted data (messages, summary, token usage, model/profile choice)
    /// is restored; transient state (pending completions, checkpoints,
    /// feedback, window errors) starts fresh.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the persisted model selection; fall back to the registry default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        // Resolve the persisted profile id against current settings; profiles
        // removed since serialization resolve to `None`.
        let configured_profile = serialized.profile.and_then(|profile| {
            AssistantSettings::get_global(cx)
                .profiles
                .get(&profile)
                .cloned()
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // structured contexts and images are not persisted.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            configured_profile,
        }
    }
559
    /// Installs a callback that observes every completion request together
    /// with the events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
            + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// This thread's stable identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// True when the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Replaces `last_prompt_id` with a fresh id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, if one exists yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, falling back to (and caching) the
    /// registry's default model when none is set.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }
606
    /// Sets (or clears) the model used for completions.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// The currently configured agent profile, if any.
    pub fn configured_profile(&self) -> Option<AgentProfile> {
        self.configured_profile.clone()
    }

    /// Sets (or clears) the agent profile.
    pub fn set_configured_profile(
        &mut self,
        profile: Option<AgentProfile>,
        cx: &mut Context<Self>,
    ) {
        self.configured_profile = profile;
        cx.notify();
    }

    /// Placeholder summary used before a real one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, or [`Self::DEFAULT_SUMMARY`] when none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    /// Replaces the summary, emitting `ThreadEvent::SummaryChanged` on change.
    /// Ignored until an initial summary has been generated; an empty new
    /// summary is replaced by the default.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }
648
    /// The thread's current completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the completion mode.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by id.
    ///
    /// Uses binary search, relying on `messages` being sorted by ascending id
    /// (ids are assigned monotonically on insert).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in insertion order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
669
    /// Whether a completion is in flight or any tool use is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// Returns `None` if no streamed chunk has been received yet.
    ///
    /// NOTE(review): the previous doc claimed this returns `None` whenever
    /// `is_generating()` is false, but `last_received_chunk_at` is never
    /// cleared in the code visible here — callers should gate on
    /// `is_generating()` themselves. TODO: confirm the intended contract.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Gap (in milliseconds) since the last chunk beyond which streaming
        // is considered stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived (resets the staleness timer).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
692
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given id, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool use is still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
718
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failures are surfaced via `last_restore_checkpoint` and
    /// `ThreadEvent::CheckpointChanged`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Only drop messages once the git restore has succeeded.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
754
    /// Resolves the checkpoint taken at the latest user message: keeps it if
    /// the project has changed since, discards it if nothing changed.
    /// No-op while a response is still being generated.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint when the project actually changed.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint to compare against; keep the
            // pending one rather than lose it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
793
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restoration, if one was attempted.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
804
805 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
806 let Some(message_ix) = self
807 .messages
808 .iter()
809 .rposition(|message| message.id == message_id)
810 else {
811 return;
812 };
813 for deleted_message in self.messages.drain(message_ix..) {
814 self.checkpoints_by_message.remove(&deleted_message.id);
815 }
816 cx.notify();
817 }
818
819 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
820 self.messages
821 .iter()
822 .find(|message| message.id == id)
823 .into_iter()
824 .flat_map(|message| message.loaded_context.contexts.iter())
825 }
826
827 pub fn is_turn_end(&self, ix: usize) -> bool {
828 if self.messages.is_empty() {
829 return false;
830 }
831
832 if !self.is_generating() && ix == self.messages.len() - 1 {
833 return true;
834 }
835
836 let Some(message) = self.messages.get(ix) else {
837 return false;
838 };
839
840 if message.role != Role::Assistant {
841 return false;
842 }
843
844 self.messages
845 .get(ix + 1)
846 .and_then(|message| {
847 self.message(message.id)
848 .map(|next_message| next_message.role == Role::User)
849 })
850 .unwrap_or(false)
851 }
852
    /// Usage reported for the most recent completion request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the provider reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
870
    /// Tool uses issued by the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The output content of a completed tool use, if any.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool use's result, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
893
894 /// Return tools that are both enabled and supported by the model
895 pub fn available_tools(
896 &self,
897 cx: &App,
898 model: Arc<dyn LanguageModel>,
899 ) -> Vec<LanguageModelRequestTool> {
900 if model.supports_tools() {
901 self.tools()
902 .read(cx)
903 .enabled_tools(cx)
904 .into_iter()
905 .filter_map(|tool| {
906 // Skip tools that cannot be supported
907 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
908 Some(LanguageModelRequestTool {
909 name: tool.name(),
910 description: tool.description(),
911 input_schema,
912 })
913 })
914 .collect()
915 } else {
916 Vec::default()
917 }
918 }
919
    /// Appends a user message with its loaded context, creases, and an
    /// optional git checkpoint; returns the new message's id.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            // Record buffers surfaced through context as read in the action log.
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // Hold the checkpoint as pending; it is kept or discarded when the
        // turn finishes (see `finalize_pending_checkpoint`).
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
955
    /// Appends an assistant message consisting of the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message with the next sequential id, bumps `updated_at`,
    /// and emits `ThreadEvent::MessageAdded`.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
990
991 pub fn edit_message(
992 &mut self,
993 id: MessageId,
994 new_role: Role,
995 new_segments: Vec<MessageSegment>,
996 loaded_context: Option<LoadedContext>,
997 cx: &mut Context<Self>,
998 ) -> bool {
999 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1000 return false;
1001 };
1002 message.role = new_role;
1003 message.segments = new_segments;
1004 if let Some(context) = loaded_context {
1005 message.loaded_context = context;
1006 }
1007 self.touch_updated_at();
1008 cx.emit(ThreadEvent::MessageEdited(id));
1009 true
1010 }
1011
1012 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1013 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1014 return false;
1015 };
1016 self.messages.remove(index);
1017 self.touch_updated_at();
1018 cx.emit(ThreadEvent::MessageDeleted(id));
1019 true
1020 }
1021
1022 /// Returns the representation of this [`Thread`] in a textual form.
1023 ///
1024 /// This is the representation we use when attaching a thread as context to another thread.
1025 pub fn text(&self) -> String {
1026 let mut text = String::new();
1027
1028 for message in &self.messages {
1029 text.push_str(match message.role {
1030 language_model::Role::User => "User:",
1031 language_model::Role::Assistant => "Agent:",
1032 language_model::Role::System => "System:",
1033 });
1034 text.push('\n');
1035
1036 for segment in &message.segments {
1037 match segment {
1038 MessageSegment::Text(content) => text.push_str(content),
1039 MessageSegment::Thinking { text: content, .. } => {
1040 text.push_str(&format!("<think>{}</think>", content))
1041 }
1042 MessageSegment::RedactedThinking(_) => {}
1043 }
1044 }
1045 text.push('\n');
1046 }
1047
1048 text
1049 }
1050
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the shared initial project snapshot before building the payload.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted;
                        // structured contexts/images are rebuilt as empty on
                        // deserialization.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                profile: this
                    .configured_profile
                    .as_ref()
                    .map(|profile| AgentProfileId(profile.name.clone().into())),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1137
    /// Returns how many agentic turns this thread may still take before it
    /// stops sending follow-up requests to the model on its own.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1141
    /// Sets the number of agentic turns the thread may take; `send_to_model`
    /// decrements this and becomes a no-op once it reaches zero.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1145
1146 pub fn send_to_model(
1147 &mut self,
1148 model: Arc<dyn LanguageModel>,
1149 window: Option<AnyWindowHandle>,
1150 cx: &mut Context<Self>,
1151 ) {
1152 if self.remaining_turns == 0 {
1153 return;
1154 }
1155
1156 self.remaining_turns -= 1;
1157
1158 let request = self.to_completion_request(model.clone(), cx);
1159
1160 self.stream_completion(request, model, window, cx);
1161 }
1162
1163 pub fn used_tools_since_last_user_message(&self) -> bool {
1164 for message in self.messages.iter().rev() {
1165 if self.tool_use.message_has_tool_results(message.id) {
1166 return true;
1167 } else if message.role == Role::User {
1168 return false;
1169 }
1170 }
1171
1172 false
1173 }
1174
    /// Assembles the complete `LanguageModelRequest` for the next turn:
    /// system prompt, conversation history (pairing each completed tool use
    /// with its result), stale-file notices, the available tool list, and
    /// the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // Tool names are also surfaced to the system prompt via `ModelContext`.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt requires the project context, which is loaded
        // asynchronously; surface an error instead of silently omitting it.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Rebuild the conversation. `message_ix_to_cache` tracks the last
        // request message that is safe to mark as a prompt-cache anchor (a
        // message with a still-pending tool use is not safe).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            // Empty text/thinking segments are dropped; redacted thinking is
            // forwarded verbatim.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results travel in a separate user-role message that
            // follows the assistant message containing the tool uses.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use would leave a dangling tool-use block
                    // in the request, so it is skipped and this message is
                    // excluded from prompt caching.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1332
1333 fn to_summarize_request(
1334 &self,
1335 model: &Arc<dyn LanguageModel>,
1336 added_user_message: String,
1337 cx: &App,
1338 ) -> LanguageModelRequest {
1339 let mut request = LanguageModelRequest {
1340 thread_id: None,
1341 prompt_id: None,
1342 mode: None,
1343 messages: vec![],
1344 tools: Vec::new(),
1345 tool_choice: None,
1346 stop: Vec::new(),
1347 temperature: AssistantSettings::temperature_for_model(model, cx),
1348 };
1349
1350 for message in &self.messages {
1351 let mut request_message = LanguageModelRequestMessage {
1352 role: message.role,
1353 content: Vec::new(),
1354 cache: false,
1355 };
1356
1357 for segment in &message.segments {
1358 match segment {
1359 MessageSegment::Text(text) => request_message
1360 .content
1361 .push(MessageContent::Text(text.clone())),
1362 MessageSegment::Thinking { .. } => {}
1363 MessageSegment::RedactedThinking(_) => {}
1364 }
1365 }
1366
1367 if request_message.content.is_empty() {
1368 continue;
1369 }
1370
1371 request.messages.push(request_message);
1372 }
1373
1374 request.messages.push(LanguageModelRequestMessage {
1375 role: Role::User,
1376 content: vec![MessageContent::Text(added_user_message)],
1377 cache: false,
1378 });
1379
1380 request
1381 }
1382
1383 fn attached_tracked_files_state(
1384 &self,
1385 messages: &mut Vec<LanguageModelRequestMessage>,
1386 cx: &App,
1387 ) {
1388 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1389
1390 let mut stale_message = String::new();
1391
1392 let action_log = self.action_log.read(cx);
1393
1394 for stale_file in action_log.stale_buffers(cx) {
1395 let Some(file) = stale_file.read(cx).file() else {
1396 continue;
1397 };
1398
1399 if stale_message.is_empty() {
1400 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1401 }
1402
1403 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1404 }
1405
1406 let mut content = Vec::with_capacity(2);
1407
1408 if !stale_message.is_empty() {
1409 content.push(stale_message.into());
1410 }
1411
1412 if !content.is_empty() {
1413 let context_message = LanguageModelRequestMessage {
1414 role: Role::User,
1415 content,
1416 cache: false,
1417 };
1418
1419 messages.push(context_message);
1420 }
1421 }
1422
1423 pub fn stream_completion(
1424 &mut self,
1425 request: LanguageModelRequest,
1426 model: Arc<dyn LanguageModel>,
1427 window: Option<AnyWindowHandle>,
1428 cx: &mut Context<Self>,
1429 ) {
1430 self.tool_use_limit_reached = false;
1431
1432 let pending_completion_id = post_inc(&mut self.completion_count);
1433 let mut request_callback_parameters = if self.request_callback.is_some() {
1434 Some((request.clone(), Vec::new()))
1435 } else {
1436 None
1437 };
1438 let prompt_id = self.last_prompt_id.clone();
1439 let tool_use_metadata = ToolUseMetadata {
1440 model: model.clone(),
1441 thread_id: self.id.clone(),
1442 prompt_id: prompt_id.clone(),
1443 };
1444
1445 self.last_received_chunk_at = Some(Instant::now());
1446
1447 let task = cx.spawn(async move |thread, cx| {
1448 let stream_completion_future = model.stream_completion(request, &cx);
1449 let initial_token_usage =
1450 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1451 let stream_completion = async {
1452 let mut events = stream_completion_future.await?;
1453
1454 let mut stop_reason = StopReason::EndTurn;
1455 let mut current_token_usage = TokenUsage::default();
1456
1457 thread
1458 .update(cx, |_thread, cx| {
1459 cx.emit(ThreadEvent::NewRequest);
1460 })
1461 .ok();
1462
1463 let mut request_assistant_message_id = None;
1464
1465 while let Some(event) = events.next().await {
1466 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1467 response_events
1468 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1469 }
1470
1471 thread.update(cx, |thread, cx| {
1472 let event = match event {
1473 Ok(event) => event,
1474 Err(LanguageModelCompletionError::BadInputJson {
1475 id,
1476 tool_name,
1477 raw_input: invalid_input_json,
1478 json_parse_error,
1479 }) => {
1480 thread.receive_invalid_tool_json(
1481 id,
1482 tool_name,
1483 invalid_input_json,
1484 json_parse_error,
1485 window,
1486 cx,
1487 );
1488 return Ok(());
1489 }
1490 Err(LanguageModelCompletionError::Other(error)) => {
1491 return Err(error);
1492 }
1493 };
1494
1495 match event {
1496 LanguageModelCompletionEvent::StartMessage { .. } => {
1497 request_assistant_message_id =
1498 Some(thread.insert_assistant_message(
1499 vec![MessageSegment::Text(String::new())],
1500 cx,
1501 ));
1502 }
1503 LanguageModelCompletionEvent::Stop(reason) => {
1504 stop_reason = reason;
1505 }
1506 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1507 thread.update_token_usage_at_last_message(token_usage);
1508 thread.cumulative_token_usage = thread.cumulative_token_usage
1509 + token_usage
1510 - current_token_usage;
1511 current_token_usage = token_usage;
1512 }
1513 LanguageModelCompletionEvent::Text(chunk) => {
1514 thread.received_chunk();
1515
1516 cx.emit(ThreadEvent::ReceivedTextChunk);
1517 if let Some(last_message) = thread.messages.last_mut() {
1518 if last_message.role == Role::Assistant
1519 && !thread.tool_use.has_tool_results(last_message.id)
1520 {
1521 last_message.push_text(&chunk);
1522 cx.emit(ThreadEvent::StreamedAssistantText(
1523 last_message.id,
1524 chunk,
1525 ));
1526 } else {
1527 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1528 // of a new Assistant response.
1529 //
1530 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1531 // will result in duplicating the text of the chunk in the rendered Markdown.
1532 request_assistant_message_id =
1533 Some(thread.insert_assistant_message(
1534 vec![MessageSegment::Text(chunk.to_string())],
1535 cx,
1536 ));
1537 };
1538 }
1539 }
1540 LanguageModelCompletionEvent::Thinking {
1541 text: chunk,
1542 signature,
1543 } => {
1544 thread.received_chunk();
1545
1546 if let Some(last_message) = thread.messages.last_mut() {
1547 if last_message.role == Role::Assistant
1548 && !thread.tool_use.has_tool_results(last_message.id)
1549 {
1550 last_message.push_thinking(&chunk, signature);
1551 cx.emit(ThreadEvent::StreamedAssistantThinking(
1552 last_message.id,
1553 chunk,
1554 ));
1555 } else {
1556 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1557 // of a new Assistant response.
1558 //
1559 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1560 // will result in duplicating the text of the chunk in the rendered Markdown.
1561 request_assistant_message_id =
1562 Some(thread.insert_assistant_message(
1563 vec![MessageSegment::Thinking {
1564 text: chunk.to_string(),
1565 signature,
1566 }],
1567 cx,
1568 ));
1569 };
1570 }
1571 }
1572 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1573 let last_assistant_message_id = request_assistant_message_id
1574 .unwrap_or_else(|| {
1575 let new_assistant_message_id =
1576 thread.insert_assistant_message(vec![], cx);
1577 request_assistant_message_id =
1578 Some(new_assistant_message_id);
1579 new_assistant_message_id
1580 });
1581
1582 let tool_use_id = tool_use.id.clone();
1583 let streamed_input = if tool_use.is_input_complete {
1584 None
1585 } else {
1586 Some((&tool_use.input).clone())
1587 };
1588
1589 let ui_text = thread.tool_use.request_tool_use(
1590 last_assistant_message_id,
1591 tool_use,
1592 tool_use_metadata.clone(),
1593 cx,
1594 );
1595
1596 if let Some(input) = streamed_input {
1597 cx.emit(ThreadEvent::StreamedToolUse {
1598 tool_use_id,
1599 ui_text,
1600 input,
1601 });
1602 }
1603 }
1604 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1605 if let Some(completion) = thread
1606 .pending_completions
1607 .iter_mut()
1608 .find(|completion| completion.id == pending_completion_id)
1609 {
1610 match status_update {
1611 CompletionRequestStatus::Queued {
1612 position,
1613 } => {
1614 completion.queue_state = QueueState::Queued { position };
1615 }
1616 CompletionRequestStatus::Started => {
1617 completion.queue_state = QueueState::Started;
1618 }
1619 CompletionRequestStatus::Failed {
1620 code, message, request_id
1621 } => {
1622 return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
1623 }
1624 CompletionRequestStatus::UsageUpdated {
1625 amount, limit
1626 } => {
1627 let usage = RequestUsage { limit, amount: amount as i32 };
1628
1629 thread.last_usage = Some(usage);
1630 }
1631 CompletionRequestStatus::ToolUseLimitReached => {
1632 thread.tool_use_limit_reached = true;
1633 }
1634 }
1635 }
1636 }
1637 }
1638
1639 thread.touch_updated_at();
1640 cx.emit(ThreadEvent::StreamedCompletion);
1641 cx.notify();
1642
1643 thread.auto_capture_telemetry(cx);
1644 Ok(())
1645 })??;
1646
1647 smol::future::yield_now().await;
1648 }
1649
1650 thread.update(cx, |thread, cx| {
1651 thread.last_received_chunk_at = None;
1652 thread
1653 .pending_completions
1654 .retain(|completion| completion.id != pending_completion_id);
1655
1656 // If there is a response without tool use, summarize the message. Otherwise,
1657 // allow two tool uses before summarizing.
1658 if thread.summary.is_none()
1659 && thread.messages.len() >= 2
1660 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1661 {
1662 thread.summarize(cx);
1663 }
1664 })?;
1665
1666 anyhow::Ok(stop_reason)
1667 };
1668
1669 let result = stream_completion.await;
1670
1671 thread
1672 .update(cx, |thread, cx| {
1673 thread.finalize_pending_checkpoint(cx);
1674 match result.as_ref() {
1675 Ok(stop_reason) => match stop_reason {
1676 StopReason::ToolUse => {
1677 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1678 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1679 }
1680 StopReason::EndTurn | StopReason::MaxTokens => {
1681 thread.project.update(cx, |project, cx| {
1682 project.set_agent_location(None, cx);
1683 });
1684 }
1685 },
1686 Err(error) => {
1687 thread.project.update(cx, |project, cx| {
1688 project.set_agent_location(None, cx);
1689 });
1690
1691 if error.is::<PaymentRequiredError>() {
1692 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1693 } else if error.is::<MaxMonthlySpendReachedError>() {
1694 cx.emit(ThreadEvent::ShowError(
1695 ThreadError::MaxMonthlySpendReached,
1696 ));
1697 } else if let Some(error) =
1698 error.downcast_ref::<ModelRequestLimitReachedError>()
1699 {
1700 cx.emit(ThreadEvent::ShowError(
1701 ThreadError::ModelRequestLimitReached { plan: error.plan },
1702 ));
1703 } else if let Some(known_error) =
1704 error.downcast_ref::<LanguageModelKnownError>()
1705 {
1706 match known_error {
1707 LanguageModelKnownError::ContextWindowLimitExceeded {
1708 tokens,
1709 } => {
1710 thread.exceeded_window_error = Some(ExceededWindowError {
1711 model_id: model.id(),
1712 token_count: *tokens,
1713 });
1714 cx.notify();
1715 }
1716 }
1717 } else {
1718 let error_message = error
1719 .chain()
1720 .map(|err| err.to_string())
1721 .collect::<Vec<_>>()
1722 .join("\n");
1723 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1724 header: "Error interacting with language model".into(),
1725 message: SharedString::from(error_message.clone()),
1726 }));
1727 }
1728
1729 thread.cancel_last_completion(window, cx);
1730 }
1731 }
1732 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1733
1734 if let Some((request_callback, (request, response_events))) = thread
1735 .request_callback
1736 .as_mut()
1737 .zip(request_callback_parameters.as_ref())
1738 {
1739 request_callback(request, response_events);
1740 }
1741
1742 thread.auto_capture_telemetry(cx);
1743
1744 if let Ok(initial_usage) = initial_token_usage {
1745 let usage = thread.cumulative_token_usage - initial_usage;
1746
1747 telemetry::event!(
1748 "Assistant Thread Completion",
1749 thread_id = thread.id().to_string(),
1750 prompt_id = prompt_id,
1751 model = model.telemetry_id(),
1752 model_provider = model.provider_id().to_string(),
1753 input_tokens = usage.input_tokens,
1754 output_tokens = usage.output_tokens,
1755 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1756 cache_read_input_tokens = usage.cache_read_input_tokens,
1757 );
1758 }
1759 })
1760 .ok();
1761 });
1762
1763 self.pending_completions.push(PendingCompletion {
1764 id: pending_completion_id,
1765 queue_state: QueueState::Sending,
1766 _task: task,
1767 });
1768 }
1769
    /// Generates a short (3-7 word) title for the thread using the configured
    /// summary model, storing it in `self.summary` and emitting
    /// `ThreadEvent::SummaryGenerated`.
    ///
    /// No-op when no summary model is configured or its provider is not
    /// authenticated; failures are logged and otherwise swallowed.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);

        self.pending_summary = cx.spawn(async move |this, cx| {
            // `log_err` at the end converts any failure into a logged `None`.
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            // Usage reported during summarization still counts.
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // Only the first line of the response is used as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1832
    /// Kicks off generation of a detailed Markdown summary of the thread,
    /// unless one is already generated (or being generated) for the latest
    /// message.
    ///
    /// Progress is published through `detailed_summary_tx`; on success the
    /// thread is saved via `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Reset state so a later call can retry the generation.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Chunks that fail to decode are logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1922
    /// Waits until detailed-summary generation settles, then returns the
    /// summary — or the raw thread text when no summary was generated.
    ///
    /// Returns `None` if the thread entity or the summary channel is dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating: wait for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1940
1941 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1942 self.detailed_summary_rx
1943 .borrow()
1944 .text()
1945 .unwrap_or_else(|| self.text().into())
1946 }
1947
1948 pub fn is_generating_detailed_summary(&self) -> bool {
1949 matches!(
1950 &*self.detailed_summary_rx.borrow(),
1951 DetailedSummaryState::Generating { .. }
1952 )
1953 }
1954
    /// Dispatches every idle pending tool use: tools that require
    /// confirmation emit `ToolConfirmationNeeded`, runnable tools are
    /// executed immediately, and tool names the model hallucinated are
    /// answered with an error result.
    ///
    /// Returns the pending tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // The current request is captured so tools can see the conversation
        // state they were invoked from.
        let request = Arc::new(self.to_completion_request(model.clone(), cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                // Tools flagged as needing confirmation wait for the user,
                // unless the user opted into always allowing tool actions.
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool that doesn't exist or is disabled.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2008
    /// Responds to a tool call whose name matches no enabled tool by
    /// recording an error tool result that lists the tools that *are*
    /// available, so the model can correct itself on the next turn.
    pub fn handle_hallucinated_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        hallucinated_tool_name: Arc<str>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let available_tools = self.tools.read(cx).enabled_tools(cx);

        let tool_list = available_tools
            .iter()
            .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
            .collect::<Vec<_>>()
            .join("\n");

        let error_message = format!(
            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
            hallucinated_tool_name, tool_list
        );

        // The error is stored as the tool's output so it flows back to the
        // model with the next request.
        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            hallucinated_tool_name,
            Err(anyhow!("Missing tool call: {error_message}")),
            self.configured_model.as_ref(),
        );

        cx.emit(ThreadEvent::MissingToolUse {
            tool_use_id: tool_use_id.clone(),
            ui_text: error_message.into(),
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
2043
    /// Records an error tool result for a tool call whose input JSON failed
    /// to parse, then finishes the tool use so the turn can continue.
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
        );
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            // A missing pending entry indicates a bookkeeping bug: every
            // finished tool use should have been registered as pending.
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
2078
2079 pub fn run_tool(
2080 &mut self,
2081 tool_use_id: LanguageModelToolUseId,
2082 ui_text: impl Into<SharedString>,
2083 input: serde_json::Value,
2084 request: Arc<LanguageModelRequest>,
2085 tool: Arc<dyn Tool>,
2086 model: Arc<dyn LanguageModel>,
2087 window: Option<AnyWindowHandle>,
2088 cx: &mut Context<Thread>,
2089 ) {
2090 let task =
2091 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2092 self.tool_use
2093 .run_pending_tool(tool_use_id, ui_text.into(), task);
2094 }
2095
    /// Spawns the asynchronous execution of a single tool use and returns the
    /// task; the tool's output is written back into the thread when it
    /// completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A tool disabled after being requested is rejected rather than run.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2146
2147 fn tool_finished(
2148 &mut self,
2149 tool_use_id: LanguageModelToolUseId,
2150 pending_tool_use: Option<PendingToolUse>,
2151 canceled: bool,
2152 window: Option<AnyWindowHandle>,
2153 cx: &mut Context<Self>,
2154 ) {
2155 if self.all_tools_finished() {
2156 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2157 if !canceled {
2158 self.send_to_model(model.clone(), window, cx);
2159 }
2160 self.auto_capture_telemetry(cx);
2161 }
2162 }
2163
2164 cx.emit(ThreadEvent::ToolFinished {
2165 tool_use_id,
2166 pending_tool_use,
2167 });
2168 }
2169
2170 /// Cancels the last pending completion, if there are any pending.
2171 ///
2172 /// Returns whether a completion was canceled.
2173 pub fn cancel_last_completion(
2174 &mut self,
2175 window: Option<AnyWindowHandle>,
2176 cx: &mut Context<Self>,
2177 ) -> bool {
2178 let mut canceled = self.pending_completions.pop().is_some();
2179
2180 for pending_tool_use in self.tool_use.cancel_pending() {
2181 canceled = true;
2182 self.tool_finished(
2183 pending_tool_use.id.clone(),
2184 Some(pending_tool_use),
2185 true,
2186 window,
2187 cx,
2188 );
2189 }
2190
2191 self.finalize_pending_checkpoint(cx);
2192
2193 if canceled {
2194 cx.emit(ThreadEvent::CompletionCanceled);
2195 }
2196
2197 canceled
2198 }
2199
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. No thread state is mutated
    /// here; cancellation is entirely event-driven.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2207
    /// Returns the thread-level feedback rating, if the user has rated this
    /// thread as a whole.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2211
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2215
    /// Records user feedback for a single message and reports it — together
    /// with a serialized snapshot of the thread and project — to telemetry.
    ///
    /// Submitting the same rating twice for the same message is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture snapshots before mutating state so the telemetry payload
        // reflects the thread as it was when rated.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2273
2274 pub fn report_feedback(
2275 &mut self,
2276 feedback: ThreadFeedback,
2277 cx: &mut Context<Self>,
2278 ) -> Task<Result<()>> {
2279 let last_assistant_message_id = self
2280 .messages
2281 .iter()
2282 .rev()
2283 .find(|msg| msg.role == Role::Assistant)
2284 .map(|msg| msg.id);
2285
2286 if let Some(message_id) = last_assistant_message_id {
2287 self.report_message_feedback(message_id, feedback, cx)
2288 } else {
2289 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2290 let serialized_thread = self.serialize(cx);
2291 let thread_id = self.id().clone();
2292 let client = self.project.read(cx).client();
2293 self.feedback = Some(feedback);
2294 cx.notify();
2295
2296 cx.background_spawn(async move {
2297 let final_project_snapshot = final_project_snapshot.await;
2298 let serialized_thread = serialized_thread.await?;
2299 let thread_data = serde_json::to_value(serialized_thread)
2300 .unwrap_or_else(|_| serde_json::Value::Null);
2301
2302 let rating = match feedback {
2303 ThreadFeedback::Positive => "positive",
2304 ThreadFeedback::Negative => "negative",
2305 };
2306 telemetry::event!(
2307 "Assistant Thread Rated",
2308 rating,
2309 thread_id,
2310 thread_data,
2311 final_project_snapshot
2312 );
2313 client.telemetry().flush_events().await;
2314
2315 Ok(())
2316 })
2317 }
2318 }
2319
2320 /// Create a snapshot of the current project state including git information and unsaved buffers.
2321 fn project_snapshot(
2322 project: Entity<Project>,
2323 cx: &mut Context<Self>,
2324 ) -> Task<Arc<ProjectSnapshot>> {
2325 let git_store = project.read(cx).git_store().clone();
2326 let worktree_snapshots: Vec<_> = project
2327 .read(cx)
2328 .visible_worktrees(cx)
2329 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2330 .collect();
2331
2332 cx.spawn(async move |_, cx| {
2333 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2334
2335 let mut unsaved_buffers = Vec::new();
2336 cx.update(|app_cx| {
2337 let buffer_store = project.read(app_cx).buffer_store();
2338 for buffer_handle in buffer_store.read(app_cx).buffers() {
2339 let buffer = buffer_handle.read(app_cx);
2340 if buffer.is_dirty() {
2341 if let Some(file) = buffer.file() {
2342 let path = file.path().to_string_lossy().to_string();
2343 unsaved_buffers.push(path);
2344 }
2345 }
2346 }
2347 })
2348 .ok();
2349
2350 Arc::new(ProjectSnapshot {
2351 worktree_snapshots,
2352 unsaved_buffer_paths: unsaved_buffers,
2353 timestamp: Utc::now(),
2354 })
2355 })
2356 }
2357
2358 fn worktree_snapshot(
2359 worktree: Entity<project::Worktree>,
2360 git_store: Entity<GitStore>,
2361 cx: &App,
2362 ) -> Task<WorktreeSnapshot> {
2363 cx.spawn(async move |cx| {
2364 // Get worktree path and snapshot
2365 let worktree_info = cx.update(|app_cx| {
2366 let worktree = worktree.read(app_cx);
2367 let path = worktree.abs_path().to_string_lossy().to_string();
2368 let snapshot = worktree.snapshot();
2369 (path, snapshot)
2370 });
2371
2372 let Ok((worktree_path, _snapshot)) = worktree_info else {
2373 return WorktreeSnapshot {
2374 worktree_path: String::new(),
2375 git_state: None,
2376 };
2377 };
2378
2379 let git_state = git_store
2380 .update(cx, |git_store, cx| {
2381 git_store
2382 .repositories()
2383 .values()
2384 .find(|repo| {
2385 repo.read(cx)
2386 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2387 .is_some()
2388 })
2389 .cloned()
2390 })
2391 .ok()
2392 .flatten()
2393 .map(|repo| {
2394 repo.update(cx, |repo, _| {
2395 let current_branch =
2396 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2397 repo.send_job(None, |state, _| async move {
2398 let RepositoryState::Local { backend, .. } = state else {
2399 return GitState {
2400 remote_url: None,
2401 head_sha: None,
2402 current_branch,
2403 diff: None,
2404 };
2405 };
2406
2407 let remote_url = backend.remote_url("origin");
2408 let head_sha = backend.head_sha().await;
2409 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2410
2411 GitState {
2412 remote_url,
2413 head_sha,
2414 current_branch,
2415 diff,
2416 }
2417 })
2418 })
2419 });
2420
2421 let git_state = match git_state {
2422 Some(git_state) => match git_state.ok() {
2423 Some(git_state) => git_state.await.ok(),
2424 None => None,
2425 },
2426 None => None,
2427 };
2428
2429 WorktreeSnapshot {
2430 worktree_path,
2431 git_state,
2432 }
2433 })
2434 }
2435
2436 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2437 let mut markdown = Vec::new();
2438
2439 if let Some(summary) = self.summary() {
2440 writeln!(markdown, "# {summary}\n")?;
2441 };
2442
2443 for message in self.messages() {
2444 writeln!(
2445 markdown,
2446 "## {role}\n",
2447 role = match message.role {
2448 Role::User => "User",
2449 Role::Assistant => "Agent",
2450 Role::System => "System",
2451 }
2452 )?;
2453
2454 if !message.loaded_context.text.is_empty() {
2455 writeln!(markdown, "{}", message.loaded_context.text)?;
2456 }
2457
2458 if !message.loaded_context.images.is_empty() {
2459 writeln!(
2460 markdown,
2461 "\n{} images attached as context.\n",
2462 message.loaded_context.images.len()
2463 )?;
2464 }
2465
2466 for segment in &message.segments {
2467 match segment {
2468 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2469 MessageSegment::Thinking { text, .. } => {
2470 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2471 }
2472 MessageSegment::RedactedThinking(_) => {}
2473 }
2474 }
2475
2476 for tool_use in self.tool_uses_for_message(message.id, cx) {
2477 writeln!(
2478 markdown,
2479 "**Use Tool: {} ({})**",
2480 tool_use.name, tool_use.id
2481 )?;
2482 writeln!(markdown, "```json")?;
2483 writeln!(
2484 markdown,
2485 "{}",
2486 serde_json::to_string_pretty(&tool_use.input)?
2487 )?;
2488 writeln!(markdown, "```")?;
2489 }
2490
2491 for tool_result in self.tool_results_for_message(message.id) {
2492 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2493 if tool_result.is_error {
2494 write!(markdown, " (Error)")?;
2495 }
2496
2497 writeln!(markdown, "**\n")?;
2498 writeln!(markdown, "{}", tool_result.content)?;
2499 }
2500 }
2501
2502 Ok(String::from_utf8_lossy(&markdown).to_string())
2503 }
2504
2505 pub fn keep_edits_in_range(
2506 &mut self,
2507 buffer: Entity<language::Buffer>,
2508 buffer_range: Range<language::Anchor>,
2509 cx: &mut Context<Self>,
2510 ) {
2511 self.action_log.update(cx, |action_log, cx| {
2512 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2513 });
2514 }
2515
2516 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2517 self.action_log
2518 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2519 }
2520
2521 pub fn reject_edits_in_ranges(
2522 &mut self,
2523 buffer: Entity<language::Buffer>,
2524 buffer_ranges: Vec<Range<language::Anchor>>,
2525 cx: &mut Context<Self>,
2526 ) -> Task<Result<()>> {
2527 self.action_log.update(cx, |action_log, cx| {
2528 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2529 })
2530 }
2531
2532 pub fn action_log(&self) -> &Entity<ActionLog> {
2533 &self.action_log
2534 }
2535
2536 pub fn project(&self) -> &Entity<Project> {
2537 &self.project
2538 }
2539
2540 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2541 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2542 return;
2543 }
2544
2545 let now = Instant::now();
2546 if let Some(last) = self.last_auto_capture_at {
2547 if now.duration_since(last).as_secs() < 10 {
2548 return;
2549 }
2550 }
2551
2552 self.last_auto_capture_at = Some(now);
2553
2554 let thread_id = self.id().clone();
2555 let github_login = self
2556 .project
2557 .read(cx)
2558 .user_store()
2559 .read(cx)
2560 .current_user()
2561 .map(|user| user.github_login.clone());
2562 let client = self.project.read(cx).client().clone();
2563 let serialize_task = self.serialize(cx);
2564
2565 cx.background_executor()
2566 .spawn(async move {
2567 if let Ok(serialized_thread) = serialize_task.await {
2568 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2569 telemetry::event!(
2570 "Agent Thread Auto-Captured",
2571 thread_id = thread_id.to_string(),
2572 thread_data = thread_data,
2573 auto_capture_reason = "tracked_user",
2574 github_login = github_login
2575 );
2576
2577 client.telemetry().flush_events().await;
2578 }
2579 }
2580 })
2581 .detach();
2582 }
2583
2584 pub fn cumulative_token_usage(&self) -> TokenUsage {
2585 self.cumulative_token_usage
2586 }
2587
2588 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2589 let Some(model) = self.configured_model.as_ref() else {
2590 return TotalTokenUsage::default();
2591 };
2592
2593 let max = model.model.max_token_count();
2594
2595 let index = self
2596 .messages
2597 .iter()
2598 .position(|msg| msg.id == message_id)
2599 .unwrap_or(0);
2600
2601 if index == 0 {
2602 return TotalTokenUsage { total: 0, max };
2603 }
2604
2605 let token_usage = &self
2606 .request_token_usage
2607 .get(index - 1)
2608 .cloned()
2609 .unwrap_or_default();
2610
2611 TotalTokenUsage {
2612 total: token_usage.total_tokens() as usize,
2613 max,
2614 }
2615 }
2616
2617 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2618 let model = self.configured_model.as_ref()?;
2619
2620 let max = model.model.max_token_count();
2621
2622 if let Some(exceeded_error) = &self.exceeded_window_error {
2623 if model.model.id() == exceeded_error.model_id {
2624 return Some(TotalTokenUsage {
2625 total: exceeded_error.token_count,
2626 max,
2627 });
2628 }
2629 }
2630
2631 let total = self
2632 .token_usage_at_last_message()
2633 .unwrap_or_default()
2634 .total_tokens() as usize;
2635
2636 Some(TotalTokenUsage { total, max })
2637 }
2638
2639 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2640 self.request_token_usage
2641 .get(self.messages.len().saturating_sub(1))
2642 .or_else(|| self.request_token_usage.last())
2643 .cloned()
2644 }
2645
2646 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2647 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2648 self.request_token_usage
2649 .resize(self.messages.len(), placeholder);
2650
2651 if let Some(last) = self.request_token_usage.last_mut() {
2652 *last = token_usage;
2653 }
2654 }
2655
2656 pub fn deny_tool_use(
2657 &mut self,
2658 tool_use_id: LanguageModelToolUseId,
2659 tool_name: Arc<str>,
2660 window: Option<AnyWindowHandle>,
2661 cx: &mut Context<Self>,
2662 ) {
2663 let err = Err(anyhow::anyhow!(
2664 "Permission to run tool action denied by user"
2665 ));
2666
2667 self.tool_use.insert_tool_output(
2668 tool_use_id.clone(),
2669 tool_name,
2670 err,
2671 self.configured_model.as_ref(),
2672 );
2673 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2674 }
2675}
2676
/// User-facing errors surfaced while running a thread completion.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's plan requires payment before further requests are allowed.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a short header and a longer message for
    /// display in the UI.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2691
/// Events emitted by a `Thread` for the UI and other observers.
///
/// NOTE(review): variant descriptions are inferred from names and the emit
/// sites visible in this file; confirm against the rest of the file.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// A `ThreadError` should be shown to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of assistant text was streamed for the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text was streamed for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool-use block is being streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Listeners (e.g. ActiveThread) should cancel in-progress edits.
    CancelEditing,
    /// Emitted by `cancel_last_completion` when something was canceled.
    CompletionCanceled,
}
2734
// Allows `Thread` to emit `ThreadEvent`s via `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
2736
/// An in-flight language-model completion owned by a `Thread`.
struct PendingCompletion {
    // Identifier for this completion request.
    id: usize,
    queue_state: QueueState,
    // Holds the completion's task alive; dropping it (e.g. when popped in
    // `cancel_last_completion`) cancels the request.
    _task: Task<()>,
}
2742
2743#[cfg(test)]
2744mod tests {
2745 use super::*;
2746 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2747 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2748 use assistant_tool::ToolRegistry;
2749 use editor::EditorSettings;
2750 use gpui::TestAppContext;
2751 use language_model::fake_provider::FakeLanguageModel;
2752 use project::{FakeFs, Project};
2753 use prompt_store::PromptBuilder;
2754 use serde_json::json;
2755 use settings::{Settings, SettingsStore};
2756 use std::sync::Arc;
2757 use theme::ThemeSettings;
2758 use util::path;
2759 use workspace::Workspace;
2760
    /// Attached file context should be embedded in the message's loaded
    /// context and prepended to the user text in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        // NOTE(review): index 1 because message 0 appears to be the system
        // prompt — confirm against `to_completion_request`.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2836
    /// Each message should only carry context that was not already attached
    /// to an earlier message, and `new_context_for_thread` should account for
    /// removals and for re-basing onto an earlier message id.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Re-basing on message 2: everything attached from message 2 onward
        // (file2, file3) plus the newly added file4 counts as "new".
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2992
    /// User messages inserted without any context should have empty loaded
    /// context and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3070
    /// When a buffer attached as context is modified after being read, the
    /// completion request should end with a "files changed since last read"
    /// notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3157
    /// `model_parameters` temperature overrides should apply when the
    /// provider and/or model match the configured model, and not otherwise.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider — the override must NOT apply.
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3251
    /// Registers the global settings and subsystems the `Thread` tests rely
    /// on. Must be called before any other fixture setup.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3268
3269 // Helper to create a test project with test files
3270 async fn create_test_project(
3271 cx: &mut TestAppContext,
3272 files: serde_json::Value,
3273 ) -> Entity<Project> {
3274 let fs = FakeFs::new(cx.executor());
3275 fs.insert_tree(path!("/test"), files).await;
3276 Project::test(fs, [path!("/test").as_ref()], cx).await
3277 }
3278
    /// Builds the standard test fixture: a workspace over `project`, a thread
    /// store with one fresh thread, an empty context store, and a
    /// `FakeLanguageModel` for inspecting completion requests.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
3313
3314 async fn add_file_to_context(
3315 project: &Entity<Project>,
3316 context_store: &Entity<ContextStore>,
3317 path: &str,
3318 cx: &mut TestAppContext,
3319 ) -> Result<Entity<language::Buffer>> {
3320 let buffer_path = project
3321 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3322 .unwrap();
3323
3324 let buffer = project
3325 .update(cx, |project, cx| {
3326 project.open_buffer(buffer_path.clone(), cx)
3327 })
3328 .await
3329 .unwrap();
3330
3331 context_store.update(cx, |context_store, cx| {
3332 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3333 });
3334
3335 Ok(buffer)
3336 }
3337}