1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, TryFutureExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
/// Per-thread, monotonically increasing identifier for a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text covered by the crease.
    pub range: Range<usize>,
    /// Icon and label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content chunks (plain text and thinking output).
    pub segments: Vec<MessageSegment>,
    /// Context that was attached when this message was sent.
    pub loaded_context: LoadedContext,
    /// Creases used to re-render collapsed context mentions in editors.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// One chunk of a message's content: plain text, extended thinking, or
/// provider-redacted thinking data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Whether this segment carries meaningful, displayable content.
    ///
    /// Bug fix: this previously returned `text.is_empty()`, which inverted the
    /// check — empty segments were treated as displayable and non-empty ones
    /// were hidden, contradicting the contract documented on
    /// [`Message::should_display_content`].
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking has no directly displayable content.
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Snapshot of the project's worktrees, captured when a thread starts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree state captured in a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git metadata, if the worktree is inside a repository.
    pub git_state: Option<GitState>,
}

/// Git repository metadata captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text, if one was captured.
    pub diff: Option<String>,
}
216
/// Associates a message with a git checkpoint taken when it was sent, so the
/// workspace can later be restored to that state.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up/down) on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
229pub enum LastRestoreCheckpoint {
230 Pending {
231 message_id: MessageId,
232 },
233 Error {
234 message_id: MessageId,
235 error: String,
236 },
237}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
249pub enum DetailedSummaryState {
250 #[default]
251 NotGenerated,
252 Generating {
253 message_id: MessageId,
254 },
255 Generated {
256 text: SharedString,
257 message_id: MessageId,
258 },
259}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token count for a thread relative to the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, near the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via an
        // environment variable for testing.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or_else(|_| "0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            return TokenUsageRatio::Normal;
        }
        if self.total >= self.max {
            return TokenUsageRatio::Exceeded;
        }
        if self.total as f32 / self.max as f32 >= warning_threshold {
            return TokenUsageRatio::Warning;
        }
        TokenUsageRatio::Normal
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close the thread's token usage is to the model's limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// State of a pending completion request relative to the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short human-readable title; `None` until one has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Messages in ID order; lookups rely on this ordering (binary search).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded for user messages, keyed by message ID.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken before the in-flight turn; kept or discarded when the
    /// turn finishes (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured asynchronously when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage recorded per request.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// When the most recent streamed chunk arrived; used to detect staleness.
    last_received_chunk_at: Option<Instant>,
    /// Optional hook observing each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Turns still allowed before `send_to_model` becomes a no-op.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
371
372impl Thread {
    /// Creates a new, empty thread with a fresh ID, using the globally
    /// configured default model.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            // No summary until one is generated from the conversation.
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project state asynchronously so thread creation
            // is not blocked on walking the worktrees.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
426
    /// Reconstructs a thread from its serialized form, restoring messages,
    /// tool-use state, and (when still available) the previously selected model.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest saved message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model that was in use when the thread was saved; fall
        // back to the global default if it can no longer be selected.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // structured context handles and images are not persisted.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Context handles are not serialized (see MessageCrease docs).
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
546
    /// Installs a callback invoked with each outgoing completion request and
    /// the events received for it (e.g. for capturing traffic in tests).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
554
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps `updated_at` to the current time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Replaces the last prompt ID with a fresh one.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The thread's title, or `None` if one hasn't been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
582
583 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
584 if self.configured_model.is_none() {
585 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
586 }
587 self.configured_model.clone()
588 }
589
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Title used before a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The thread's title, falling back to [`Self::DEFAULT_SUMMARY`].
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
604
605 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
606 let Some(current_summary) = &self.summary else {
607 // Don't allow setting summary until generated
608 return;
609 };
610
611 let mut new_summary = new_summary.into();
612
613 if new_summary.is_empty() {
614 new_summary = Self::DEFAULT_SUMMARY;
615 }
616
617 if current_summary != &new_summary {
618 self.summary = Some(new_summary);
619 cx.emit(ThreadEvent::SummaryChanged);
620 }
621 }
622
    /// The currently selected completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
630
631 pub fn message(&self, id: MessageId) -> Option<&Message> {
632 let index = self
633 .messages
634 .binary_search_by(|message| message.id.cmp(&id))
635 .ok()?;
636
637 self.messages.get(index)
638 }
639
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or any tool use is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// Returns `None` when no chunk has been received yet.
    /// NOTE(review): callers appear to expect `None` whenever `is_generating()`
    /// is false; that relies on `last_received_chunk_at` being cleared
    /// elsewhere — confirm.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before streaming counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
666
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given ID, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses waiting on user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for message `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
692
    /// Restores the project to `checkpoint`'s git state, truncating the thread
    /// back to the checkpoint's message on success.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as pending so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed state around so the error can be shown.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any checkpoint pending for the current turn is now moot.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
728
    /// Decides whether the checkpoint taken for the current turn is worth
    /// keeping: it is stored only if the repository state actually changed
    /// while the turn ran (or if the comparison could not be made).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Wait until generation (including tool use) has fully finished.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint if the repo changed during the turn.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Taking the final checkpoint failed: keep the pending one rather
            // than losing the restore point.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
767
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The most recent checkpoint-restore attempt, if one is pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
778
779 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
780 let Some(message_ix) = self
781 .messages
782 .iter()
783 .rposition(|message| message.id == message_id)
784 else {
785 return;
786 };
787 for deleted_message in self.messages.drain(message_ix..) {
788 self.checkpoints_by_message.remove(&deleted_message.id);
789 }
790 cx.notify();
791 }
792
    /// The structured context items attached to message `id` (empty when the
    /// message doesn't exist or has no context).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
800
801 pub fn is_turn_end(&self, ix: usize) -> bool {
802 if self.messages.is_empty() {
803 return false;
804 }
805
806 if !self.is_generating() && ix == self.messages.len() - 1 {
807 return true;
808 }
809
810 let Some(message) = self.messages.get(ix) else {
811 return false;
812 };
813
814 if message.role != Role::Assistant {
815 return false;
816 }
817
818 self.messages
819 .get(ix + 1)
820 .and_then(|message| {
821 self.message(message.id)
822 .map(|next_message| next_message.role == Role::User)
823 })
824 .unwrap_or(false)
825 }
826
    /// Usage reported by the most recent completion response, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
844
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced in response to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw content of a completed tool use's result, if present.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The rendered card associated with a tool result, if the tool made one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
867
868 /// Return tools that are both enabled and supported by the model
869 pub fn available_tools(
870 &self,
871 cx: &App,
872 model: Arc<dyn LanguageModel>,
873 ) -> Vec<LanguageModelRequestTool> {
874 if model.supports_tools() {
875 self.tools()
876 .read(cx)
877 .enabled_tools(cx)
878 .into_iter()
879 .filter_map(|tool| {
880 // Skip tools that cannot be supported
881 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
882 Some(LanguageModelRequestTool {
883 name: tool.name(),
884 description: tool.description(),
885 input_schema,
886 })
887 })
888 .collect()
889 } else {
890 Vec::default()
891 }
892 }
893
    /// Appends a user message, recording context buffer reads in the action
    /// log and stashing `git_checkpoint` until the turn completes.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Mark buffers referenced by the attached context as read so the
        // action log can track subsequent edits to them.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // The checkpoint is kept or discarded once the turn ends
        // (see `finalize_pending_checkpoint`).
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
929
    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
943
    /// Appends a message with the next sequential ID, emits
    /// [`ThreadEvent::MessageAdded`], and returns the new message's ID.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
964
    /// Replaces the role, segments, and (optionally) context of message `id`.
    /// Returns false when no such message exists.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        // `None` leaves the existing context untouched.
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
985
986 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
987 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
988 return false;
989 };
990 self.messages.remove(index);
991 self.touch_updated_at();
992 cx.emit(ThreadEvent::MessageDeleted(id));
993 true
994 }
995
996 /// Returns the representation of this [`Thread`] in a textual form.
997 ///
998 /// This is the representation we use when attaching a thread as context to another thread.
999 pub fn text(&self) -> String {
1000 let mut text = String::new();
1001
1002 for message in &self.messages {
1003 text.push_str(match message.role {
1004 language_model::Role::User => "User:",
1005 language_model::Role::Assistant => "Agent:",
1006 language_model::Role::System => "System:",
1007 });
1008 text.push('\n');
1009
1010 for segment in &message.segments {
1011 match segment {
1012 MessageSegment::Text(content) => text.push_str(content),
1013 MessageSegment::Thinking { text: content, .. } => {
1014 text.push_str(&format!("<think>{}</think>", content))
1015 }
1016 MessageSegment::RedactedThinking(_) => {}
1017 }
1018 }
1019 text.push('\n');
1020 }
1021
1022 text
1023 }
1024
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot captured at thread creation to resolve
            // before reading the rest of the thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results are stored on the thread's
                        // ToolUseState, not on the messages themselves.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1107
    /// Number of turns `send_to_model` will still perform before refusing.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1115
    /// Sends the current conversation to `model`, consuming one remaining
    /// turn. Does nothing when no turns are left.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1132
1133 pub fn used_tools_since_last_user_message(&self) -> bool {
1134 for message in self.messages.iter().rev() {
1135 if self.tool_use.message_has_tool_results(message.id) {
1136 return true;
1137 } else if message.role == Role::User {
1138 return false;
1139 }
1140 }
1141
1142 false
1143 }
1144
    /// Builds the full `LanguageModelRequest` for this thread: system prompt,
    /// conversation history (including tool uses and their results), a
    /// trailing note about stale files, and the available tool list.
    ///
    /// Emits a `ThreadEvent::ShowError` if the system prompt cannot be
    /// generated; the request is still returned (without a system message).
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // The system prompt mentions which tools exist, so gather their names
        // before generating it.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is stable across requests, so mark it
                    // cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // The project context is loaded asynchronously; callers should
            // normally not reach this before it is ready.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Track the last request message eligible for a cache breakpoint; only
        // messages whose tool uses have all resolved qualify.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results must be delivered in a user-role message that
            // follows the assistant message containing the tool uses.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use has no result yet, so this message
                    // can't be cached (its content will change).
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1302
1303 fn to_summarize_request(
1304 &self,
1305 model: &Arc<dyn LanguageModel>,
1306 added_user_message: String,
1307 cx: &App,
1308 ) -> LanguageModelRequest {
1309 let mut request = LanguageModelRequest {
1310 thread_id: None,
1311 prompt_id: None,
1312 mode: None,
1313 messages: vec![],
1314 tools: Vec::new(),
1315 tool_choice: None,
1316 stop: Vec::new(),
1317 temperature: AssistantSettings::temperature_for_model(model, cx),
1318 };
1319
1320 for message in &self.messages {
1321 let mut request_message = LanguageModelRequestMessage {
1322 role: message.role,
1323 content: Vec::new(),
1324 cache: false,
1325 };
1326
1327 for segment in &message.segments {
1328 match segment {
1329 MessageSegment::Text(text) => request_message
1330 .content
1331 .push(MessageContent::Text(text.clone())),
1332 MessageSegment::Thinking { .. } => {}
1333 MessageSegment::RedactedThinking(_) => {}
1334 }
1335 }
1336
1337 if request_message.content.is_empty() {
1338 continue;
1339 }
1340
1341 request.messages.push(request_message);
1342 }
1343
1344 request.messages.push(LanguageModelRequestMessage {
1345 role: Role::User,
1346 content: vec![MessageContent::Text(added_user_message)],
1347 cache: false,
1348 });
1349
1350 request
1351 }
1352
1353 fn attached_tracked_files_state(
1354 &self,
1355 messages: &mut Vec<LanguageModelRequestMessage>,
1356 cx: &App,
1357 ) {
1358 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1359
1360 let mut stale_message = String::new();
1361
1362 let action_log = self.action_log.read(cx);
1363
1364 for stale_file in action_log.stale_buffers(cx) {
1365 let Some(file) = stale_file.read(cx).file() else {
1366 continue;
1367 };
1368
1369 if stale_message.is_empty() {
1370 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1371 }
1372
1373 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1374 }
1375
1376 let mut content = Vec::with_capacity(2);
1377
1378 if !stale_message.is_empty() {
1379 content.push(stale_message.into());
1380 }
1381
1382 if !content.is_empty() {
1383 let context_message = LanguageModelRequestMessage {
1384 role: Role::User,
1385 content,
1386 cache: false,
1387 };
1388
1389 messages.push(context_message);
1390 }
1391 }
1392
1393 pub fn stream_completion(
1394 &mut self,
1395 request: LanguageModelRequest,
1396 model: Arc<dyn LanguageModel>,
1397 window: Option<AnyWindowHandle>,
1398 cx: &mut Context<Self>,
1399 ) {
1400 self.tool_use_limit_reached = false;
1401
1402 let pending_completion_id = post_inc(&mut self.completion_count);
1403 let mut request_callback_parameters = if self.request_callback.is_some() {
1404 Some((request.clone(), Vec::new()))
1405 } else {
1406 None
1407 };
1408 let prompt_id = self.last_prompt_id.clone();
1409 let tool_use_metadata = ToolUseMetadata {
1410 model: model.clone(),
1411 thread_id: self.id.clone(),
1412 prompt_id: prompt_id.clone(),
1413 };
1414
1415 self.last_received_chunk_at = Some(Instant::now());
1416
1417 let task = cx.spawn(async move |thread, cx| {
1418 let stream_completion_future = model.stream_completion(request, &cx);
1419 let initial_token_usage =
1420 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1421 let stream_completion = async {
1422 let mut events = stream_completion_future.await?;
1423
1424 let mut stop_reason = StopReason::EndTurn;
1425 let mut current_token_usage = TokenUsage::default();
1426
1427 thread
1428 .update(cx, |_thread, cx| {
1429 cx.emit(ThreadEvent::NewRequest);
1430 })
1431 .ok();
1432
1433 let mut request_assistant_message_id = None;
1434
1435 while let Some(event) = events.next().await {
1436 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1437 response_events
1438 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1439 }
1440
1441 thread.update(cx, |thread, cx| {
1442 let event = match event {
1443 Ok(event) => event,
1444 Err(LanguageModelCompletionError::BadInputJson {
1445 id,
1446 tool_name,
1447 raw_input: invalid_input_json,
1448 json_parse_error,
1449 }) => {
1450 thread.receive_invalid_tool_json(
1451 id,
1452 tool_name,
1453 invalid_input_json,
1454 json_parse_error,
1455 window,
1456 cx,
1457 );
1458 return Ok(());
1459 }
1460 Err(LanguageModelCompletionError::Other(error)) => {
1461 return Err(error);
1462 }
1463 };
1464
1465 match event {
1466 LanguageModelCompletionEvent::StartMessage { .. } => {
1467 request_assistant_message_id =
1468 Some(thread.insert_assistant_message(
1469 vec![MessageSegment::Text(String::new())],
1470 cx,
1471 ));
1472 }
1473 LanguageModelCompletionEvent::Stop(reason) => {
1474 stop_reason = reason;
1475 }
1476 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1477 thread.update_token_usage_at_last_message(token_usage);
1478 thread.cumulative_token_usage = thread.cumulative_token_usage
1479 + token_usage
1480 - current_token_usage;
1481 current_token_usage = token_usage;
1482 }
1483 LanguageModelCompletionEvent::Text(chunk) => {
1484 thread.received_chunk();
1485
1486 cx.emit(ThreadEvent::ReceivedTextChunk);
1487 if let Some(last_message) = thread.messages.last_mut() {
1488 if last_message.role == Role::Assistant
1489 && !thread.tool_use.has_tool_results(last_message.id)
1490 {
1491 last_message.push_text(&chunk);
1492 cx.emit(ThreadEvent::StreamedAssistantText(
1493 last_message.id,
1494 chunk,
1495 ));
1496 } else {
1497 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1498 // of a new Assistant response.
1499 //
1500 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1501 // will result in duplicating the text of the chunk in the rendered Markdown.
1502 request_assistant_message_id =
1503 Some(thread.insert_assistant_message(
1504 vec![MessageSegment::Text(chunk.to_string())],
1505 cx,
1506 ));
1507 };
1508 }
1509 }
1510 LanguageModelCompletionEvent::Thinking {
1511 text: chunk,
1512 signature,
1513 } => {
1514 thread.received_chunk();
1515
1516 if let Some(last_message) = thread.messages.last_mut() {
1517 if last_message.role == Role::Assistant
1518 && !thread.tool_use.has_tool_results(last_message.id)
1519 {
1520 last_message.push_thinking(&chunk, signature);
1521 cx.emit(ThreadEvent::StreamedAssistantThinking(
1522 last_message.id,
1523 chunk,
1524 ));
1525 } else {
1526 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1527 // of a new Assistant response.
1528 //
1529 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1530 // will result in duplicating the text of the chunk in the rendered Markdown.
1531 request_assistant_message_id =
1532 Some(thread.insert_assistant_message(
1533 vec![MessageSegment::Thinking {
1534 text: chunk.to_string(),
1535 signature,
1536 }],
1537 cx,
1538 ));
1539 };
1540 }
1541 }
1542 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1543 let last_assistant_message_id = request_assistant_message_id
1544 .unwrap_or_else(|| {
1545 let new_assistant_message_id =
1546 thread.insert_assistant_message(vec![], cx);
1547 request_assistant_message_id =
1548 Some(new_assistant_message_id);
1549 new_assistant_message_id
1550 });
1551
1552 let tool_use_id = tool_use.id.clone();
1553 let streamed_input = if tool_use.is_input_complete {
1554 None
1555 } else {
1556 Some((&tool_use.input).clone())
1557 };
1558
1559 let ui_text = thread.tool_use.request_tool_use(
1560 last_assistant_message_id,
1561 tool_use,
1562 tool_use_metadata.clone(),
1563 cx,
1564 );
1565
1566 if let Some(input) = streamed_input {
1567 cx.emit(ThreadEvent::StreamedToolUse {
1568 tool_use_id,
1569 ui_text,
1570 input,
1571 });
1572 }
1573 }
1574 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1575 if let Some(completion) = thread
1576 .pending_completions
1577 .iter_mut()
1578 .find(|completion| completion.id == pending_completion_id)
1579 {
1580 match status_update {
1581 CompletionRequestStatus::Queued {
1582 position,
1583 } => {
1584 completion.queue_state = QueueState::Queued { position };
1585 }
1586 CompletionRequestStatus::Started => {
1587 completion.queue_state = QueueState::Started;
1588 }
1589 CompletionRequestStatus::Failed {
1590 code, message, request_id
1591 } => {
1592 return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
1593 }
1594 CompletionRequestStatus::UsageUpdated {
1595 amount, limit
1596 } => {
1597 let usage = RequestUsage { limit, amount: amount as i32 };
1598
1599 thread.last_usage = Some(usage);
1600 }
1601 CompletionRequestStatus::ToolUseLimitReached => {
1602 thread.tool_use_limit_reached = true;
1603 }
1604 }
1605 }
1606 }
1607 }
1608
1609 thread.touch_updated_at();
1610 cx.emit(ThreadEvent::StreamedCompletion);
1611 cx.notify();
1612
1613 thread.auto_capture_telemetry(cx);
1614 Ok(())
1615 })??;
1616
1617 smol::future::yield_now().await;
1618 }
1619
1620 thread.update(cx, |thread, cx| {
1621 thread.last_received_chunk_at = None;
1622 thread
1623 .pending_completions
1624 .retain(|completion| completion.id != pending_completion_id);
1625
1626 // If there is a response without tool use, summarize the message. Otherwise,
1627 // allow two tool uses before summarizing.
1628 if thread.summary.is_none()
1629 && thread.messages.len() >= 2
1630 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1631 {
1632 thread.summarize(cx);
1633 }
1634 })?;
1635
1636 anyhow::Ok(stop_reason)
1637 };
1638
1639 let result = stream_completion.await;
1640
1641 thread
1642 .update(cx, |thread, cx| {
1643 thread.finalize_pending_checkpoint(cx);
1644 match result.as_ref() {
1645 Ok(stop_reason) => match stop_reason {
1646 StopReason::ToolUse => {
1647 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1648 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1649 }
1650 StopReason::EndTurn | StopReason::MaxTokens => {
1651 thread.project.update(cx, |project, cx| {
1652 project.set_agent_location(None, cx);
1653 });
1654 }
1655 },
1656 Err(error) => {
1657 thread.project.update(cx, |project, cx| {
1658 project.set_agent_location(None, cx);
1659 });
1660
1661 if error.is::<PaymentRequiredError>() {
1662 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1663 } else if error.is::<MaxMonthlySpendReachedError>() {
1664 cx.emit(ThreadEvent::ShowError(
1665 ThreadError::MaxMonthlySpendReached,
1666 ));
1667 } else if let Some(error) =
1668 error.downcast_ref::<ModelRequestLimitReachedError>()
1669 {
1670 cx.emit(ThreadEvent::ShowError(
1671 ThreadError::ModelRequestLimitReached { plan: error.plan },
1672 ));
1673 } else if let Some(known_error) =
1674 error.downcast_ref::<LanguageModelKnownError>()
1675 {
1676 match known_error {
1677 LanguageModelKnownError::ContextWindowLimitExceeded {
1678 tokens,
1679 } => {
1680 thread.exceeded_window_error = Some(ExceededWindowError {
1681 model_id: model.id(),
1682 token_count: *tokens,
1683 });
1684 cx.notify();
1685 }
1686 }
1687 } else {
1688 let error_message = error
1689 .chain()
1690 .map(|err| err.to_string())
1691 .collect::<Vec<_>>()
1692 .join("\n");
1693 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1694 header: "Error interacting with language model".into(),
1695 message: SharedString::from(error_message.clone()),
1696 }));
1697 }
1698
1699 thread.cancel_last_completion(window, cx);
1700 }
1701 }
1702 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1703
1704 if let Some((request_callback, (request, response_events))) = thread
1705 .request_callback
1706 .as_mut()
1707 .zip(request_callback_parameters.as_ref())
1708 {
1709 request_callback(request, response_events);
1710 }
1711
1712 thread.auto_capture_telemetry(cx);
1713
1714 if let Ok(initial_usage) = initial_token_usage {
1715 let usage = thread.cumulative_token_usage - initial_usage;
1716
1717 telemetry::event!(
1718 "Assistant Thread Completion",
1719 thread_id = thread.id().to_string(),
1720 prompt_id = prompt_id,
1721 model = model.telemetry_id(),
1722 model_provider = model.provider_id().to_string(),
1723 input_tokens = usage.input_tokens,
1724 output_tokens = usage.output_tokens,
1725 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1726 cache_read_input_tokens = usage.cache_read_input_tokens,
1727 );
1728 }
1729 })
1730 .ok();
1731 });
1732
1733 self.pending_completions.push(PendingCompletion {
1734 id: pending_completion_id,
1735 queue_state: QueueState::Sending,
1736 _task: task,
1737 });
1738 }
1739
    /// Kicks off generation of a short (3-7 word) thread title using the
    /// configured summary model, storing the result in `self.summary`.
    ///
    /// No-op when no summary model is configured or its provider is not
    /// authenticated. Replaces any in-flight summarization task.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);

        self.pending_summary = cx.spawn(async move |this, cx| {
            // Inner async block so `.log_err()` can swallow any error from the
            // whole summarization attempt.
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            // Keep usage tracking current even for the
                            // summarization request itself.
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // Only the first line of output is used as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1802
    /// Starts generating a detailed Markdown summary of the thread unless one
    /// is already generated (or generating) for the current last message.
    ///
    /// On completion the summary is broadcast via `detailed_summary_tx` and
    /// the thread is saved through `thread_store` so the summary can be
    /// reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed; reset the state so a later call retries.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1892
    /// Waits for any in-flight detailed-summary generation to settle, then
    /// returns the generated summary, or the raw thread text when generation
    /// never ran (or was reset after failing). Returns `None` if the thread
    /// entity is dropped while waiting.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating; wait for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1910
1911 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1912 self.detailed_summary_rx
1913 .borrow()
1914 .text()
1915 .unwrap_or_else(|| self.text().into())
1916 }
1917
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1924
1925 pub fn use_pending_tools(
1926 &mut self,
1927 window: Option<AnyWindowHandle>,
1928 cx: &mut Context<Self>,
1929 model: Arc<dyn LanguageModel>,
1930 ) -> Vec<PendingToolUse> {
1931 self.auto_capture_telemetry(cx);
1932 let request = Arc::new(self.to_completion_request(model.clone(), cx));
1933 let pending_tool_uses = self
1934 .tool_use
1935 .pending_tool_uses()
1936 .into_iter()
1937 .filter(|tool_use| tool_use.status.is_idle())
1938 .cloned()
1939 .collect::<Vec<_>>();
1940
1941 for tool_use in pending_tool_uses.iter() {
1942 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1943 if tool.needs_confirmation(&tool_use.input, cx)
1944 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1945 {
1946 self.tool_use.confirm_tool_use(
1947 tool_use.id.clone(),
1948 tool_use.ui_text.clone(),
1949 tool_use.input.clone(),
1950 request.clone(),
1951 tool,
1952 );
1953 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1954 } else {
1955 self.run_tool(
1956 tool_use.id.clone(),
1957 tool_use.ui_text.clone(),
1958 tool_use.input.clone(),
1959 request.clone(),
1960 tool,
1961 model.clone(),
1962 window,
1963 cx,
1964 );
1965 }
1966 } else {
1967 self.handle_hallucinated_tool_use(
1968 tool_use.id.clone(),
1969 tool_use.name.clone(),
1970 window,
1971 cx,
1972 );
1973 }
1974 }
1975
1976 pending_tool_uses
1977 }
1978
1979 pub fn handle_hallucinated_tool_use(
1980 &mut self,
1981 tool_use_id: LanguageModelToolUseId,
1982 hallucinated_tool_name: Arc<str>,
1983 window: Option<AnyWindowHandle>,
1984 cx: &mut Context<Thread>,
1985 ) {
1986 let available_tools = self.tools.read(cx).enabled_tools(cx);
1987
1988 let tool_list = available_tools
1989 .iter()
1990 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
1991 .collect::<Vec<_>>()
1992 .join("\n");
1993
1994 let error_message = format!(
1995 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
1996 hallucinated_tool_name, tool_list
1997 );
1998
1999 let pending_tool_use = self.tool_use.insert_tool_output(
2000 tool_use_id.clone(),
2001 hallucinated_tool_name,
2002 Err(anyhow!("Missing tool call: {error_message}")),
2003 self.configured_model.as_ref(),
2004 );
2005
2006 cx.emit(ThreadEvent::MissingToolUse {
2007 tool_use_id: tool_use_id.clone(),
2008 ui_text: error_message.into(),
2009 });
2010
2011 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2012 }
2013
2014 pub fn receive_invalid_tool_json(
2015 &mut self,
2016 tool_use_id: LanguageModelToolUseId,
2017 tool_name: Arc<str>,
2018 invalid_json: Arc<str>,
2019 error: String,
2020 window: Option<AnyWindowHandle>,
2021 cx: &mut Context<Thread>,
2022 ) {
2023 log::error!("The model returned invalid input JSON: {invalid_json}");
2024
2025 let pending_tool_use = self.tool_use.insert_tool_output(
2026 tool_use_id.clone(),
2027 tool_name,
2028 Err(anyhow!("Error parsing input JSON: {error}")),
2029 self.configured_model.as_ref(),
2030 );
2031 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2032 pending_tool_use.ui_text.clone()
2033 } else {
2034 log::error!(
2035 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2036 );
2037 format!("Unknown tool {}", tool_use_id).into()
2038 };
2039
2040 cx.emit(ThreadEvent::InvalidToolInput {
2041 tool_use_id: tool_use_id.clone(),
2042 ui_text,
2043 invalid_input_json: invalid_json,
2044 });
2045
2046 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2047 }
2048
    /// Runs `tool` with `input` and registers the resulting task with the
    /// tool-use state so its progress is reflected in the thread UI under
    /// `ui_text`.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2065
    /// Spawns the async execution of a single tool use and returns the task.
    /// When the tool finishes, its output is recorded on the thread and
    /// `tool_finished` is invoked (which may continue the model turn).
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A disabled tool short-circuits to an immediate error result instead
        // of being run.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; `.ok()`
                // discards the update in that case.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2116
2117 fn tool_finished(
2118 &mut self,
2119 tool_use_id: LanguageModelToolUseId,
2120 pending_tool_use: Option<PendingToolUse>,
2121 canceled: bool,
2122 window: Option<AnyWindowHandle>,
2123 cx: &mut Context<Self>,
2124 ) {
2125 if self.all_tools_finished() {
2126 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2127 if !canceled {
2128 self.send_to_model(model.clone(), window, cx);
2129 }
2130 self.auto_capture_telemetry(cx);
2131 }
2132 }
2133
2134 cx.emit(ThreadEvent::ToolFinished {
2135 tool_use_id,
2136 pending_tool_use,
2137 });
2138 }
2139
2140 /// Cancels the last pending completion, if there are any pending.
2141 ///
2142 /// Returns whether a completion was canceled.
2143 pub fn cancel_last_completion(
2144 &mut self,
2145 window: Option<AnyWindowHandle>,
2146 cx: &mut Context<Self>,
2147 ) -> bool {
2148 let mut canceled = self.pending_completions.pop().is_some();
2149
2150 for pending_tool_use in self.tool_use.cancel_pending() {
2151 canceled = true;
2152 self.tool_finished(
2153 pending_tool_use.id.clone(),
2154 Some(pending_tool_use),
2155 true,
2156 window,
2157 cx,
2158 );
2159 }
2160
2161 self.finalize_pending_checkpoint(cx);
2162
2163 if canceled {
2164 cx.emit(ThreadEvent::CompletionCanceled);
2165 }
2166
2167 canceled
2168 }
2169
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2177
    /// Returns the thread-level feedback rating, if the user has provided one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2181
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2185
    /// Records `feedback` for `message_id` and reports it (together with a
    /// serialized copy of the thread and a project snapshot) to telemetry.
    ///
    /// Repeated submissions of the same rating for the same message are
    /// ignored. The returned task completes when the telemetry event has been
    /// flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op when the rating hasn't changed.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure shouldn't block reporting; fall back to null.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2243
    /// Records thread-level `feedback` and reports it to telemetry.
    ///
    /// When the thread contains an assistant message, the rating is attached
    /// to the most recent one via `report_message_feedback`; otherwise a
    /// thread-level telemetry event (without message fields) is sent.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // A serialization failure shouldn't block reporting; fall back to null.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2289
2290 /// Create a snapshot of the current project state including git information and unsaved buffers.
2291 fn project_snapshot(
2292 project: Entity<Project>,
2293 cx: &mut Context<Self>,
2294 ) -> Task<Arc<ProjectSnapshot>> {
2295 let git_store = project.read(cx).git_store().clone();
2296 let worktree_snapshots: Vec<_> = project
2297 .read(cx)
2298 .visible_worktrees(cx)
2299 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2300 .collect();
2301
2302 cx.spawn(async move |_, cx| {
2303 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2304
2305 let mut unsaved_buffers = Vec::new();
2306 cx.update(|app_cx| {
2307 let buffer_store = project.read(app_cx).buffer_store();
2308 for buffer_handle in buffer_store.read(app_cx).buffers() {
2309 let buffer = buffer_handle.read(app_cx);
2310 if buffer.is_dirty() {
2311 if let Some(file) = buffer.file() {
2312 let path = file.path().to_string_lossy().to_string();
2313 unsaved_buffers.push(path);
2314 }
2315 }
2316 }
2317 })
2318 .ok();
2319
2320 Arc::new(ProjectSnapshot {
2321 worktree_snapshots,
2322 unsaved_buffer_paths: unsaved_buffers,
2323 timestamp: Utc::now(),
2324 })
2325 })
2326 }
2327
    /// Snapshot a single worktree: its absolute path plus, when the worktree
    /// belongs to a local git repository, branch/remote/HEAD/diff state.
    /// All failures degrade to empty/`None` fields rather than erroring.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository that contains this worktree's root, if any,
            // then queue a job on it to read git metadata off the main thread.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can
                            // query; otherwise report the branch name alone.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Flatten Option<Result<Task<_>>> down to Option<GitState>,
            // dropping every failure case to best-effort `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2405
2406 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2407 let mut markdown = Vec::new();
2408
2409 if let Some(summary) = self.summary() {
2410 writeln!(markdown, "# {summary}\n")?;
2411 };
2412
2413 for message in self.messages() {
2414 writeln!(
2415 markdown,
2416 "## {role}\n",
2417 role = match message.role {
2418 Role::User => "User",
2419 Role::Assistant => "Agent",
2420 Role::System => "System",
2421 }
2422 )?;
2423
2424 if !message.loaded_context.text.is_empty() {
2425 writeln!(markdown, "{}", message.loaded_context.text)?;
2426 }
2427
2428 if !message.loaded_context.images.is_empty() {
2429 writeln!(
2430 markdown,
2431 "\n{} images attached as context.\n",
2432 message.loaded_context.images.len()
2433 )?;
2434 }
2435
2436 for segment in &message.segments {
2437 match segment {
2438 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2439 MessageSegment::Thinking { text, .. } => {
2440 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2441 }
2442 MessageSegment::RedactedThinking(_) => {}
2443 }
2444 }
2445
2446 for tool_use in self.tool_uses_for_message(message.id, cx) {
2447 writeln!(
2448 markdown,
2449 "**Use Tool: {} ({})**",
2450 tool_use.name, tool_use.id
2451 )?;
2452 writeln!(markdown, "```json")?;
2453 writeln!(
2454 markdown,
2455 "{}",
2456 serde_json::to_string_pretty(&tool_use.input)?
2457 )?;
2458 writeln!(markdown, "```")?;
2459 }
2460
2461 for tool_result in self.tool_results_for_message(message.id) {
2462 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2463 if tool_result.is_error {
2464 write!(markdown, " (Error)")?;
2465 }
2466
2467 writeln!(markdown, "**\n")?;
2468 writeln!(markdown, "{}", tool_result.content)?;
2469 if let Some(output) = tool_result.output.as_ref() {
2470 writeln!(
2471 markdown,
2472 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2473 serde_json::to_string_pretty(output)?
2474 )?;
2475 }
2476 }
2477 }
2478
2479 Ok(String::from_utf8_lossy(&markdown).to_string())
2480 }
2481
    /// Mark agent edits within `buffer_range` of `buffer` as kept (accepted)
    /// in the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2492
    /// Mark every outstanding agent edit as kept (accepted) in the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2497
    /// Revert agent edits within the given ranges of `buffer`, delegating to
    /// the action log. Returns the task performing the reverts.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2508
    /// The action log tracking agent edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2512
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2516
2517 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2518 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2519 return;
2520 }
2521
2522 let now = Instant::now();
2523 if let Some(last) = self.last_auto_capture_at {
2524 if now.duration_since(last).as_secs() < 10 {
2525 return;
2526 }
2527 }
2528
2529 self.last_auto_capture_at = Some(now);
2530
2531 let thread_id = self.id().clone();
2532 let github_login = self
2533 .project
2534 .read(cx)
2535 .user_store()
2536 .read(cx)
2537 .current_user()
2538 .map(|user| user.github_login.clone());
2539 let client = self.project.read(cx).client().clone();
2540 let serialize_task = self.serialize(cx);
2541
2542 cx.background_executor()
2543 .spawn(async move {
2544 if let Ok(serialized_thread) = serialize_task.await {
2545 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2546 telemetry::event!(
2547 "Agent Thread Auto-Captured",
2548 thread_id = thread_id.to_string(),
2549 thread_data = thread_data,
2550 auto_capture_reason = "tracked_user",
2551 github_login = github_login
2552 );
2553
2554 client.telemetry().flush_events().await;
2555 }
2556 }
2557 })
2558 .detach();
2559 }
2560
    /// Total token usage accumulated across all requests made by this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2564
2565 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2566 let Some(model) = self.configured_model.as_ref() else {
2567 return TotalTokenUsage::default();
2568 };
2569
2570 let max = model.model.max_token_count();
2571
2572 let index = self
2573 .messages
2574 .iter()
2575 .position(|msg| msg.id == message_id)
2576 .unwrap_or(0);
2577
2578 if index == 0 {
2579 return TotalTokenUsage { total: 0, max };
2580 }
2581
2582 let token_usage = &self
2583 .request_token_usage
2584 .get(index - 1)
2585 .cloned()
2586 .unwrap_or_default();
2587
2588 TotalTokenUsage {
2589 total: token_usage.total_tokens() as usize,
2590 max,
2591 }
2592 }
2593
2594 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2595 let model = self.configured_model.as_ref()?;
2596
2597 let max = model.model.max_token_count();
2598
2599 if let Some(exceeded_error) = &self.exceeded_window_error {
2600 if model.model.id() == exceeded_error.model_id {
2601 return Some(TotalTokenUsage {
2602 total: exceeded_error.token_count,
2603 max,
2604 });
2605 }
2606 }
2607
2608 let total = self
2609 .token_usage_at_last_message()
2610 .unwrap_or_default()
2611 .total_tokens() as usize;
2612
2613 Some(TotalTokenUsage { total, max })
2614 }
2615
    /// Token usage recorded at the index of the last message, falling back to
    /// the most recent recorded usage when the usage list is shorter than the
    /// message list. `None` only when no usage has ever been recorded.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2622
    /// Overwrite the usage entry corresponding to the last message, first
    /// resizing the usage list to line up with `self.messages` (padding any
    /// new slots with the latest known usage rather than zeros).
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        // With no messages the list is empty and there is nothing to update.
        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2632
2633 pub fn deny_tool_use(
2634 &mut self,
2635 tool_use_id: LanguageModelToolUseId,
2636 tool_name: Arc<str>,
2637 window: Option<AnyWindowHandle>,
2638 cx: &mut Context<Self>,
2639 ) {
2640 let err = Err(anyhow::anyhow!(
2641 "Permission to run tool action denied by user"
2642 ));
2643
2644 self.tool_use.insert_tool_output(
2645 tool_use_id.clone(),
2646 tool_name,
2647 err,
2648 self.configured_model.as_ref(),
2649 );
2650 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2651 }
2652}
2653
/// User-facing errors surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further completions can run.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's configured monthly spend cap has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The per-plan model request quota has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2668
/// Events emitted by a [`Thread`] for the UI and other observers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    /// A completion chunk was streamed in.
    StreamedCompletion,
    /// A text chunk was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use is being streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model stopped with a tool-use reason but provided no tool use.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, successfully or not.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A git checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Any in-progress editing should be canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2711
// `Thread` emits `ThreadEvent`s; marker impl wiring it into gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2713
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier distinguishing this completion from others.
    id: usize,
    queue_state: QueueState,
    // Held so the spawned work keeps running for the life of this entry.
    _task: Task<()>,
}
2719
2720#[cfg(test)]
2721mod tests {
2722 use super::*;
2723 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2724 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2725 use assistant_tool::ToolRegistry;
2726 use editor::EditorSettings;
2727 use gpui::TestAppContext;
2728 use language_model::fake_provider::FakeLanguageModel;
2729 use project::{FakeFs, Project};
2730 use prompt_store::PromptBuilder;
2731 use serde_json::json;
2732 use settings::{Settings, SettingsStore};
2733 use std::sync::Arc;
2734 use theme::ThemeSettings;
2735 use util::path;
2736 use workspace::Workspace;
2737
    // Verifies that a user message with attached file context carries the
    // rendered <context> block both on the stored message and in the
    // completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // messages[0] is the system prompt; messages[1] is our user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2813
    // Verifies that each new user message only attaches contexts that were
    // not already included by an earlier message, and that
    // `new_context_for_thread` honors an "up to message" cutoff and context
    // removal.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, contexts from message 2 onward count as
        // "new" again: expect file2, file3, and the freshly added file4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2969
    // Verifies that user messages with no attached context have empty
    // context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3047
    // Verifies that once a context buffer is edited after being read, the
    // completion request gains a trailing user message listing the stale
    // files.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3134
    // Verifies that the `model_parameters` temperature override matches on
    // provider and/or model: matching either (or both) applies the override,
    // while a provider mismatch for the same model name does not.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3228
    // Registers the global settings and subsystems the thread tests depend
    // on. The settings store must be installed first; the other init calls
    // register against it.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3245
    // Helper to create a test project with test files
    // `files` is a JSON tree mapping file names to contents, rooted at /test.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
3255
    // Builds the standard test fixture: a workspace over `project`, a thread
    // store with an empty tool working set, one fresh thread, an empty
    // context store, and a fake language model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
3290
    // Opens `path` in the project and registers its buffer in the context
    // store, returning the buffer for further manipulation by the test.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
3314}